diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..a9d9f51 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,140 @@ +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.3.0 + hooks: + - id: check-ast + exclude: ^(tests|samples)/ + - id: sort-simple-yaml + exclude: ^(tests|samples)/ + - id: check-yaml + exclude: | + (?x)^( + meta.yaml + | tests/ + | samples/ + )$ + - id: check-xml + exclude: ^(tests|samples)/ + - id: check-toml + exclude: ^(tests|samples)/ + - id: check-docstring-first + exclude: ^(tests|samples)/ + - id: check-json + exclude: ^(tests|samples)/ + - id: fix-encoding-pragma + exclude: ^(tests|samples)/ + - id: detect-private-key + exclude: ^(tests|samples)/ + - id: trailing-whitespace + exclude: ^(tests|samples)/ + - repo: https://github.com/asottile/add-trailing-comma + rev: v3.1.0 + hooks: + - id: add-trailing-comma + exclude: ^(tests|samples)/ + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.7.0 + hooks: + - id: mypy + exclude: + (?x)( + pb2\.py$ + | grpc\.py$ + | ^docs + | ^tests/ + | ^samples/ + | \.html$ + ) + args: [ + --ignore-missing-imports, + --disable-error-code=var-annotated, + --disable-error-code=union-attr, + --disable-error-code=assignment, + --disable-error-code=attr-defined, + --disable-error-code=import-untyped, + --disable-error-code=truthy-function, + --follow-imports=skip, + --explicit-package-bases, + ] + - repo: https://github.com/psf/black + rev: 23.3.0 + hooks: + - id: black + args: [ --line-length=79 ] + exclude: ^(tests|samples)/ + - repo: https://github.com/PyCQA/flake8 + rev: 6.1.0 + hooks: + - id: flake8 + args: [ "--extend-ignore=E203"] + exclude: ^(tests|samples)/ + - repo: https://github.com/pylint-dev/pylint + rev: v3.0.2 + hooks: + - id: pylint + exclude: + (?x)( + ^docs + | ^tests/ + | ^samples/ + | pb2\.py$ + | grpc\.py$ + | \.demo$ + | \.md$ + | \.html$ + ) + args: [ + --disable=W0511, + --disable=W0718, + --disable=W0122, + 
--disable=C0103, + --disable=R0913, + --disable=E0401, + --disable=E1101, + --disable=C0415, + --disable=W0603, + --disable=R1705, + --disable=R0914, + --disable=E0601, + --disable=W0602, + --disable=W0604, + --disable=R0801, + --disable=R0902, + --disable=R0903, + --disable=C0123, + --disable=W0231, + --disable=W1113, + --disable=W0221, + --disable=R0401, + --disable=W0632, + --disable=W0123, + --disable=C3001, + --disable=W0201, + --disable=C0302, + --disable=W1203, + --disable=C2801, + --disable=C0114, # Disable missing module docstring for quick dev + --disable=C0115, # Disable missing class docstring for quick dev + --disable=C0116, # Disable missing function or method docstring for quick dev + ] + - repo: https://github.com/pre-commit/mirrors-eslint + rev: v7.32.0 + hooks: + - id: eslint + files: \.(js|jsx)$ + exclude: '(.*js_third_party.*|^tests/|^samples/)' + args: [ '--fix' ] + - repo: https://github.com/thibaudcolas/pre-commit-stylelint + rev: v14.4.0 + hooks: + - id: stylelint + files: \.(css)$ + exclude: '(.*css_third_party.*|^tests/|^samples/)' + args: [ '--fix' ] + - repo: https://github.com/pre-commit/mirrors-prettier + rev: 'v3.0.0' + hooks: + - id: prettier + additional_dependencies: [ 'prettier@3.0.0' ] + files: \.(tsx?)$ + exclude: ^(tests|samples)/ \ No newline at end of file diff --git a/dashscope/__init__.py b/dashscope/__init__.py index 45e5819..744b269 100644 --- a/dashscope/__init__.py +++ b/dashscope/__init__.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
import logging @@ -7,7 +8,10 @@ from dashscope.aigc.conversation import Conversation, History, HistoryItem from dashscope.aigc.generation import AioGeneration, Generation from dashscope.aigc.image_synthesis import ImageSynthesis -from dashscope.aigc.multimodal_conversation import MultiModalConversation, AioMultiModalConversation +from dashscope.aigc.multimodal_conversation import ( + MultiModalConversation, + AioMultiModalConversation, +) from dashscope.aigc.video_synthesis import VideoSynthesis from dashscope.app.application import Application from dashscope.assistants import Assistant, AssistantList, Assistants @@ -15,81 +19,105 @@ from dashscope.audio.asr.transcription import Transcription from dashscope.audio.tts.speech_synthesizer import SpeechSynthesizer from dashscope.common.api_key import save_api_key -from dashscope.common.env import (api_key, api_key_file_path, - base_http_api_url, base_websocket_api_url) +from dashscope.common.env import ( + api_key, + api_key_file_path, + base_http_api_url, + base_websocket_api_url, +) from dashscope.customize.deployments import Deployments from dashscope.customize.finetunes import FineTunes from dashscope.embeddings.batch_text_embedding import BatchTextEmbedding -from dashscope.embeddings.batch_text_embedding_response import \ - BatchTextEmbeddingResponse +from dashscope.embeddings.batch_text_embedding_response import ( + BatchTextEmbeddingResponse, +) from dashscope.embeddings.multimodal_embedding import ( - MultiModalEmbedding, MultiModalEmbeddingItemAudio, - MultiModalEmbeddingItemImage, MultiModalEmbeddingItemText, AioMultiModalEmbedding) + MultiModalEmbedding, + MultiModalEmbeddingItemAudio, + MultiModalEmbeddingItemImage, + MultiModalEmbeddingItemText, + AioMultiModalEmbedding, +) from dashscope.embeddings.text_embedding import TextEmbedding from dashscope.files import Files from dashscope.models import Models from dashscope.nlp.understanding import Understanding from dashscope.rerank.text_rerank import 
TextReRank -from dashscope.threads import (MessageFile, Messages, Run, RunList, Runs, - RunStep, RunStepList, Steps, Thread, - ThreadMessage, ThreadMessageList, Threads) -from dashscope.tokenizers import (Tokenization, Tokenizer, get_tokenizer, - list_tokenizers) - -__all__ = [ - base_http_api_url, - base_websocket_api_url, - api_key, - api_key_file_path, - save_api_key, - AioGeneration, - Conversation, - Generation, - History, - HistoryItem, - ImageSynthesis, - Transcription, - Files, - Deployments, - FineTunes, - Models, - TextEmbedding, - MultiModalEmbedding, - AioMultiModalEmbedding, - MultiModalEmbeddingItemAudio, - MultiModalEmbeddingItemImage, - MultiModalEmbeddingItemText, - SpeechSynthesizer, - MultiModalConversation, - AioMultiModalConversation, - BatchTextEmbedding, - BatchTextEmbeddingResponse, - Understanding, - CodeGeneration, - Tokenization, - Tokenizer, - get_tokenizer, - list_tokenizers, - Application, - TextReRank, - Assistants, - Threads, +from dashscope.threads import ( + MessageFile, Messages, - Runs, - Assistant, - ThreadMessage, Run, - Steps, - AssistantList, - ThreadMessageList, RunList, + Runs, + RunStep, RunStepList, + Steps, Thread, - DeleteResponse, - RunStep, - MessageFile, - AssistantFile, - VideoSynthesis, + ThreadMessage, + ThreadMessageList, + Threads, +) +from dashscope.tokenizers import ( + Tokenization, + Tokenizer, + get_tokenizer, + list_tokenizers, +) + +__all__ = [ + "base_http_api_url", + "base_websocket_api_url", + "api_key", + "api_key_file_path", + "save_api_key", + "AioGeneration", + "Conversation", + "Generation", + "History", + "HistoryItem", + "ImageSynthesis", + "Transcription", + "Files", + "Deployments", + "FineTunes", + "Models", + "TextEmbedding", + "MultiModalEmbedding", + "AioMultiModalEmbedding", + "MultiModalEmbeddingItemAudio", + "MultiModalEmbeddingItemImage", + "MultiModalEmbeddingItemText", + "SpeechSynthesizer", + "MultiModalConversation", + "AioMultiModalConversation", + "BatchTextEmbedding", + 
"BatchTextEmbeddingResponse", + "Understanding", + "CodeGeneration", + "Tokenization", + "Tokenizer", + "get_tokenizer", + "list_tokenizers", + "Application", + "TextReRank", + "Assistants", + "Threads", + "Messages", + "Runs", + "Assistant", + "ThreadMessage", + "Run", + "Steps", + "AssistantList", + "ThreadMessageList", + "RunList", + "RunStepList", + "Thread", + "DeleteResponse", + "RunStep", + "MessageFile", + "AssistantFile", + "VideoSynthesis", ] logging.getLogger(__name__).addHandler(NullHandler()) diff --git a/dashscope/aigc/__init__.py b/dashscope/aigc/__init__.py index 0b2f4e7..6f1e32c 100644 --- a/dashscope/aigc/__init__.py +++ b/dashscope/aigc/__init__.py @@ -1,20 +1,24 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. from .conversation import Conversation, History, HistoryItem from .generation import Generation, AioGeneration from .image_synthesis import ImageSynthesis, AioImageSynthesis -from .multimodal_conversation import MultiModalConversation, AioMultiModalConversation +from .multimodal_conversation import ( + MultiModalConversation, + AioMultiModalConversation, +) from .video_synthesis import VideoSynthesis, AioVideoSynthesis __all__ = [ - Generation, - AioGeneration, - Conversation, - HistoryItem, - History, - ImageSynthesis, - AioImageSynthesis, - MultiModalConversation, - AioMultiModalConversation, - VideoSynthesis, - AioVideoSynthesis, + "Generation", + "AioGeneration", + "Conversation", + "HistoryItem", + "History", + "ImageSynthesis", + "AioImageSynthesis", + "MultiModalConversation", + "AioMultiModalConversation", + "VideoSynthesis", + "AioVideoSynthesis", ] diff --git a/dashscope/aigc/chat_completion.py b/dashscope/aigc/chat_completion.py index 34c30f6..7df150c 100644 --- a/dashscope/aigc/chat_completion.py +++ b/dashscope/aigc/chat_completion.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
import json @@ -5,20 +6,23 @@ import dashscope from dashscope.aigc.generation import Generation -from dashscope.api_entities.chat_completion_types import (ChatCompletion, - ChatCompletionChunk) -from dashscope.api_entities.dashscope_response import (GenerationResponse, - Message) +from dashscope.api_entities.chat_completion_types import ( + ChatCompletion, + ChatCompletionChunk, +) +from dashscope.api_entities.dashscope_response import ( + GenerationResponse, + Message, +) from dashscope.client.base_api import BaseAioApi, CreateMixin from dashscope.common.error import InputRequired, ModelRequired from dashscope.common.utils import _get_task_group_and_task class Completions(CreateMixin): - """Support openai compatible chat completion interface. + """Support openai compatible chat completion interface.""" - """ - SUB_PATH = '' + SUB_PATH = "" @classmethod def create( @@ -37,7 +41,7 @@ def create( workspace: str = None, extra_headers: Dict = None, extra_body: Dict = None, - **kwargs + **kwargs, ) -> Union[ChatCompletion, Generator[ChatCompletionChunk, None, None]]: """Call openai compatible chat completion model service. @@ -49,7 +53,7 @@ def create( 'content': 'The weather is fine today.'}, {'role': 'assistant', 'content': 'Suitable for outings'}] stream(bool, `optional`): Enable server-sent events - (ref: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events) # noqa E501 + (ref: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events) # noqa E501 # pylint: disable=line-too-long the result will back partially[qwen-turbo,bailian-v1]. temperature(float, `optional`): Used to control the degree of randomness and diversity. Specifically, the temperature @@ -67,23 +71,23 @@ def create( tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. - top_k(int, `optional`): The size of the sample candidate set when generated. 
# noqa E501 + top_k(int, `optional`): The size of the sample candidate set when generated. # noqa E501 # pylint: disable=line-too-long For example, when the value is 50, only the 50 highest-scoring tokens # noqa E501 in a single generation form a randomly sampled candidate set. # noqa E501 The larger the value, the higher the randomness generated; # noqa E501 the smaller the value, the higher the certainty generated. # noqa E501 The default value is 0, which means the top_k policy is # noqa E501 not enabled. At this time, only the top_p policy takes effect. # noqa E501 - stop(list[str] or list[list[int]], `optional`): Used to control the generation to stop # noqa E501 + stop(list[str] or list[list[int]], `optional`): Used to control the generation to stop # noqa E501 # pylint: disable=line-too-long when encountering setting str or token ids, the result will not include # noqa E501 stop words or tokens. - max_tokens(int, `optional`): The maximum token num expected to be output. It should be # noqa E501 - noted that the length generated by the model will only be less than max_tokens, # noqa E501 - not necessarily equal to it. If max_tokens is set too large, the service will # noqa E501 + max_tokens(int, `optional`): The maximum token num expected to be output. It should be # noqa E501 # pylint: disable=line-too-long + noted that the length generated by the model will only be less than max_tokens, # noqa E501 # pylint: disable=line-too-long + not necessarily equal to it. If max_tokens is set too large, the service will # noqa E501 # pylint: disable=line-too-long directly prompt that the length exceeds the limit. It is generally # noqa E501 not recommended to set this value. - repetition_penalty(float, `optional`): Used to control the repeatability when generating models. # noqa E501 - Increasing repetition_penalty can reduce the duplication of model generation. 
# noqa E501 + repetition_penalty(float, `optional`): Used to control the repeatability when generating models. # noqa E501 # pylint: disable=line-too-long + Increasing repetition_penalty can reduce the duplication of model generation. # noqa E501 # pylint: disable=line-too-long 1.0 means no punishment. api_key (str, optional): The api api_key, can be None, if None, will get by default rule. @@ -99,45 +103,51 @@ def create( stream is True, return Generator, otherwise ChatCompletion. """ if messages is None or not messages: - raise InputRequired('Messages is required!') + raise InputRequired("Messages is required!") if model is None or not model: - raise ModelRequired('Model is required!') + raise ModelRequired("Model is required!") data = {} - data['model'] = model - data['messages'] = messages + data["model"] = model + data["messages"] = messages if temperature is not None: - data['temperature'] = temperature + data["temperature"] = temperature if top_p is not None: - data['top_p'] = top_p + data["top_p"] = top_p if top_k is not None: - data['top_k'] = top_k + data["top_k"] = top_k if stop is not None: - data['stop'] = stop + data["stop"] = stop if max_tokens is not None: - data[max_tokens] = max_tokens + data[max_tokens] = max_tokens # type: ignore[index] if repetition_penalty is not None: - data['repetition_penalty'] = repetition_penalty + data["repetition_penalty"] = repetition_penalty if extra_body is not None and extra_body: data = {**data, **extra_body} if extra_headers is not None and extra_headers: - kwargs = { - 'headers': extra_headers - } if kwargs else { - **kwargs, - **{ - 'headers': extra_headers + kwargs = ( + { + "headers": extra_headers, + } + if kwargs + else { + **kwargs, + **{ + "headers": extra_headers, + }, } - } + ) - response = super().call(data=data, - path='chat/completions', - base_address=dashscope.base_compatible_api_url, - api_key=api_key, - flattened_output=True, - stream=stream, - workspace=workspace, - **kwargs) + response = 
super().call( + data=data, + path="chat/completions", + base_address=dashscope.base_compatible_api_url, + api_key=api_key, + flattened_output=True, + stream=stream, + workspace=workspace, + **kwargs, + ) if stream: return (ChatCompletionChunk(**item) for _, item in response) else: @@ -145,24 +155,27 @@ def create( class AioGeneration(BaseAioApi): - task = 'text-generation' + task = "text-generation" """API for AI-Generated Content(AIGC) models. """ + class Models: """@deprecated, use qwen_turbo instead""" - qwen_v1 = 'qwen-v1' + + qwen_v1 = "qwen-v1" """@deprecated, use qwen_plus instead""" - qwen_plus_v1 = 'qwen-plus-v1' + qwen_plus_v1 = "qwen-plus-v1" - bailian_v1 = 'bailian-v1' - dolly_12b_v2 = 'dolly-12b-v2' - qwen_turbo = 'qwen-turbo' - qwen_plus = 'qwen-plus' - qwen_max = 'qwen-max' + bailian_v1 = "bailian-v1" + dolly_12b_v2 = "dolly-12b-v2" + qwen_turbo = "qwen-turbo" + qwen_plus = "qwen-plus" + qwen_max = "qwen-max" @classmethod - async def call( + # type: ignore[override] + async def call( # pylint: disable=arguments-renamed # type: ignore[override] # noqa: E501 cls, model: str, prompt: Any = None, @@ -171,7 +184,7 @@ async def call( messages: List[Message] = None, plugins: Union[str, Dict[str, Any]] = None, workspace: str = None, - **kwargs + **kwargs, ) -> Union[GenerationResponse, Generator[GenerationResponse, None, None]]: """Call generation model service. @@ -193,7 +206,7 @@ async def call( plugins (Any): The plugin config. Can be plugins config str, or dict. **kwargs: stream(bool, `optional`): Enable server-sent events - (ref: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events) # noqa E501 + (ref: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events) # noqa E501 # pylint: disable=line-too-long the result will back partially[qwen-turbo,bailian-v1]. temperature(float, `optional`): Used to control the degree of randomness and diversity. 
Specifically, the temperature @@ -211,8 +224,8 @@ async def call( tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered[qwen-turbo,bailian-v1]. - top_k(int, `optional`): The size of the sample candidate set when generated. # noqa E501 - For example, when the value is 50, only the 50 highest-scoring tokens # noqa E501 + top_k(int, `optional`): The size of the sample candidate set when generated. # noqa E501 # pylint: disable=line-too-long + For example, when the value is 50, only the 50 highest-scoring tokens # noqa E501 # pylint: disable=line-too-long in a single generation form a randomly sampled candidate set. # noqa E501 The larger the value, the higher the randomness generated; # noqa E501 the smaller the value, the higher the certainty generated. # noqa E501 @@ -227,20 +240,20 @@ async def call( large model product, support model: [bailian-v1]. result_format(str, `optional`): [message|text] Set result result format. # noqa E501 Default result is text - incremental_output(bool, `optional`): Used to control the streaming output mode. # noqa E501 - If true, the subsequent output will include the previously input content. # noqa E501 - Otherwise, the subsequent output will not include the previously output # noqa E501 + incremental_output(bool, `optional`): Used to control the streaming output mode. # noqa E501 # pylint: disable=line-too-long + If true, the subsequent output will include the previously input content. # noqa E501 # pylint: disable=line-too-long + Otherwise, the subsequent output will not include the previously output # noqa E501 # pylint: disable=line-too-long content. Default false. 
- stop(list[str] or list[list[int]], `optional`): Used to control the generation to stop # noqa E501 - when encountering setting str or token ids, the result will not include # noqa E501 + stop(list[str] or list[list[int]], `optional`): Used to control the generation to stop # noqa E501 # pylint: disable=line-too-long + when encountering setting str or token ids, the result will not include # noqa E501 # pylint: disable=line-too-long stop words or tokens. - max_tokens(int, `optional`): The maximum token num expected to be output. It should be # noqa E501 - noted that the length generated by the model will only be less than max_tokens, # noqa E501 - not necessarily equal to it. If max_tokens is set too large, the service will # noqa E501 + max_tokens(int, `optional`): The maximum token num expected to be output. It should be # noqa E501 # pylint: disable=line-too-long + noted that the length generated by the model will only be less than max_tokens, # noqa E501 # pylint: disable=line-too-long + not necessarily equal to it. If max_tokens is set too large, the service will # noqa E501 # pylint: disable=line-too-long directly prompt that the length exceeds the limit. It is generally # noqa E501 not recommended to set this value. - repetition_penalty(float, `optional`): Used to control the repeatability when generating models. # noqa E501 - Increasing repetition_penalty can reduce the duplication of model generation. # noqa E501 + repetition_penalty(float, `optional`): Used to control the repeatability when generating models. # noqa E501 # pylint: disable=line-too-long + Increasing repetition_penalty can reduce the duplication of model generation. # noqa E501 # pylint: disable=line-too-long 1.0 means no punishment. workspace (str): The dashscope workspace id. Raises: @@ -251,32 +264,46 @@ async def call( Generator[GenerationResponse, None, None]]: If stream is True, return Generator, otherwise GenerationResponse. 
""" - if (prompt is None or not prompt) and (messages is None - or not messages): - raise InputRequired('prompt or messages is required!') + if (prompt is None or not prompt) and ( + messages is None or not messages + ): + raise InputRequired("prompt or messages is required!") if model is None or not model: - raise ModelRequired('Model is required!') + raise ModelRequired("Model is required!") task_group, function = _get_task_group_and_task(__name__) if plugins is not None: - headers = kwargs.pop('headers', {}) + headers = kwargs.pop("headers", {}) if isinstance(plugins, str): - headers['X-DashScope-Plugin'] = plugins + headers["X-DashScope-Plugin"] = plugins else: - headers['X-DashScope-Plugin'] = json.dumps(plugins) - kwargs['headers'] = headers - input, parameters = Generation._build_input_parameters( - model, prompt, history, messages, **kwargs) - response = await super().call(model=model, - task_group=task_group, - task=Generation.task, - function=function, - api_key=api_key, - input=input, - workspace=workspace, - **parameters) - is_stream = kwargs.get('stream', False) + headers["X-DashScope-Plugin"] = json.dumps(plugins) + kwargs["headers"] = headers + # pylint: disable=protected-access + ( + input, # pylint: disable=redefined-builtin + parameters, + ) = Generation._build_input_parameters( + model, + prompt, + history, + messages, + **kwargs, + ) + response = await super().call( + model=model, + task_group=task_group, + task=Generation.task, + function=function, + api_key=api_key, + input=input, + workspace=workspace, + **parameters, + ) + is_stream = kwargs.get("stream", False) if is_stream: - return (GenerationResponse.from_api_response(rsp) - async for rsp in response) + return ( # type: ignore[return-value] + GenerationResponse.from_api_response(rsp) + async for rsp in response + ) else: return GenerationResponse.from_api_response(response) diff --git a/dashscope/aigc/code_generation.py b/dashscope/aigc/code_generation.py index 6d8ef0c..a5f819f 100644 
--- a/dashscope/aigc/code_generation.py +++ b/dashscope/aigc/code_generation.py @@ -1,9 +1,13 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. from typing import Generator, List, Union -from dashscope.api_entities.dashscope_response import (DashScopeAPIResponse, - DictMixin, Role) +from dashscope.api_entities.dashscope_response import ( + DashScopeAPIResponse, + DictMixin, + Role, +) from dashscope.client.base_api import BaseApi from dashscope.common.constants import MESSAGE, SCENE from dashscope.common.error import InputRequired, ModelRequired @@ -46,34 +50,41 @@ def __init__(self, role: str, meta: dict, **kwargs): class CodeGeneration(BaseApi): - function = 'generation' + function = "generation" """API for AI-Generated Content(AIGC) models. """ + class Models: - tongyi_lingma_v1 = 'tongyi-lingma-v1' + tongyi_lingma_v1 = "tongyi-lingma-v1" class Scenes: - custom = 'custom' - nl2code = 'nl2code' - code2comment = 'code2comment' - code2explain = 'code2explain' - commit2msg = 'commit2msg' - unit_test = 'unittest' - code_qa = 'codeqa' - nl2sql = 'nl2sql' + custom = "custom" + nl2code = "nl2code" + code2comment = "code2comment" + code2explain = "code2explain" + commit2msg = "commit2msg" + unit_test = "unittest" + code_qa = "codeqa" + nl2sql = "nl2sql" @classmethod - def call( + def call( # type: ignore[override] cls, model: str, scene: str = None, api_key: str = None, message: List[MessageParam] = None, workspace: str = None, - **kwargs - ) -> Union[DashScopeAPIResponse, Generator[DashScopeAPIResponse, None, - None]]: + **kwargs, + ) -> Union[ + DashScopeAPIResponse, + Generator[ + DashScopeAPIResponse, + None, + None, + ], + ]: """Call generation model service. Args: @@ -92,24 +103,27 @@ def call( if None, will get by default rule(TODO: api key doc). message (list): The generation messages. 
scene == custom, examples: - [{"role": "user", "content": "根据下面的功能描述生成一个python函数。代码的功能是计算给定路径下所有文件的总大小。"}] # noqa E501 + [{"role": "user", "content": "根据下面的功能描述生成一个python函数。代码的功能是计算给定路径下所有文件的总大小。"}] # noqa E501 # pylint: disable=line-too-long scene == nl2code, examples: - [{"role": "user", "content": "计算给定路径下所有文件的总大小"}, {"role": "attachment", "meta": {"language": "java"}}] # noqa E501 + [{"role": "user", "content": "计算给定路径下所有文件的总大小"}, {"role": "attachment", "meta": {"language": "java"}}] # noqa E501 # pylint: disable=line-too-long scene == code2comment, examples: - [{"role": "user", "content": "1. 生成中文注释\n2. 仅生成代码部分,不需要额外解释函数功能\n"}, {"role": "attachment", "meta": {"code": "\t\t@Override\n\t\tpublic CancelExportTaskResponse cancelExportTask(\n\t\t\t\tCancelExportTask cancelExportTask) {\n\t\t\tAmazonEC2SkeletonInterface ec2Service = ServiceProvider.getInstance().getServiceImpl(AmazonEC2SkeletonInterface.class);\n\t\t\treturn ec2Service.cancelExportTask(cancelExportTask);\n\t\t}", "language": "java"}}] # noqa E501 + [{"role": "user", "content": "1. 生成中文注释\n2. 
仅生成代码部分,不需要额外解释函数功能\n"}, {"role": "attachment", "meta": {"code": "\t\t@Override\n\t\tpublic CancelExportTaskResponse cancelExportTask(\n\t\t\t\tCancelExportTask cancelExportTask) {\n\t\t\tAmazonEC2SkeletonInterface ec2Service = ServiceProvider.getInstance().getServiceImpl(AmazonEC2SkeletonInterface.class);\n\t\t\treturn ec2Service.cancelExportTask(cancelExportTask);\n\t\t}", "language": "java"}}] # noqa E501 # pylint: disable=line-too-long scene == code2explain, examples: - [{"role": "user", "content": "要求不低于200字"}, {"role": "attachment", "meta": {"code": "@Override\n public int getHeaderCacheSize()\n {\n return 0;\n }\n\n", "language": "java"}}] # noqa E501 + [{"role": "user", "content": "要求不低于200字"}, {"role": "attachment", "meta": {"code": "@Override\n public int getHeaderCacheSize()\n {\n return 0;\n }\n\n", "language": "java"}}] # noqa E501 # pylint: disable=line-too-long scene == commit2msg, examples: - [{"role": "attachment", "meta": {"diff_list": [{"diff": "--- src/com/siondream/core/PlatformResolver.java\n+++ src/com/siondream/core/PlatformResolver.java\n@@ -1,11 +1,8 @@\npackage com.siondream.core;\n-\n-import com.badlogic.gdx.files.FileHandle;\n\npublic interface PlatformResolver {\npublic void openURL(String url);\npublic void rateApp();\npublic void sendFeedback();\n-\tpublic FileHandle[] listFolder(String path);\n}\n", "old_file_path": "src/com/siondream/core/PlatformResolver.java", "new_file_path": "src/com/siondream/core/PlatformResolver.java"}]}}] # noqa E501 + [{"role": "attachment", "meta": {"diff_list": [{"diff": "--- src/com/siondream/core/PlatformResolver.java\n+++ src/com/siondream/core/PlatformResolver.java\n@@ -1,11 +1,8 @@\npackage com.siondream.core;\n-\n-import com.badlogic.gdx.files.FileHandle;\n\npublic interface PlatformResolver {\npublic void openURL(String url);\npublic void rateApp();\npublic void sendFeedback();\n-\tpublic FileHandle[] listFolder(String path);\n}\n", "old_file_path": "src/com/siondream/core/PlatformResolver.java", 
"new_file_path": "src/com/siondream/core/PlatformResolver.java"}]}}] # noqa E501 # pylint: disable=line-too-long scene == unittest, examples: - [{"role": "attachment", "meta": {"code": "public static TimestampMap parseTimestampMap(Class typeClass, String input, DateTimeZone timeZone) throws IllegalArgumentException {\n if (typeClass == null) {\n throw new IllegalArgumentException(\"typeClass required\");\n }\n\n if (input == null) {\n return null;\n }\n\n TimestampMap result;\n\n typeClass = AttributeUtils.getStandardizedType(typeClass);\n if (typeClass.equals(String.class)) {\n result = new TimestampStringMap();\n } else if (typeClass.equals(Byte.class)) {\n result = new TimestampByteMap();\n } else if (typeClass.equals(Short.class)) {\n result = new TimestampShortMap();\n } else if (typeClass.equals(Integer.class)) {\n result = new TimestampIntegerMap();\n } else if (typeClass.equals(Long.class)) {\n result = new TimestampLongMap();\n } else if (typeClass.equals(Float.class)) {\n result = new TimestampFloatMap();\n } else if (typeClass.equals(Double.class)) {\n result = new TimestampDoubleMap();\n } else if (typeClass.equals(Boolean.class)) {\n result = new TimestampBooleanMap();\n } else if (typeClass.equals(Character.class)) {\n result = new TimestampCharMap();\n } else {\n throw new IllegalArgumentException(\"Unsupported type \" + typeClass.getClass().getCanonicalName());\n }\n\n if (input.equalsIgnoreCase(EMPTY_VALUE)) {\n return result;\n }\n\n StringReader reader = new StringReader(input + ' ');// Add 1 space so\n // reader.skip\n // function always\n // works when\n // necessary (end of\n // string not\n // reached).\n\n try {\n int r;\n char c;\n while ((r = reader.read()) != -1) {\n c = (char) r;\n switch (c) {\n case LEFT_BOUND_SQUARE_BRACKET:\n case LEFT_BOUND_BRACKET:\n parseTimestampAndValue(typeClass, reader, result, timeZone);\n break;\n default:\n // Ignore other chars outside of bounds\n }\n }\n } catch (IOException ex) {\n throw new 
RuntimeException(\"Unexpected expection while parsing timestamps\", ex);\n }\n\n return result;\n }", "language": "java"}}] # noqa E501 + [{"role": "attachment", "meta": {"code": "public static TimestampMap parseTimestampMap(Class typeClass, String input, DateTimeZone timeZone) throws IllegalArgumentException {\n if (typeClass == null) {\n throw new IllegalArgumentException(\"typeClass required\");\n }\n\n if (input == null) {\n return null;\n }\n\n TimestampMap result;\n\n typeClass = AttributeUtils.getStandardizedType(typeClass);\n if (typeClass.equals(String.class)) {\n result = new TimestampStringMap();\n } else if (typeClass.equals(Byte.class)) {\n result = new TimestampByteMap();\n } else if (typeClass.equals(Short.class)) {\n result = new TimestampShortMap();\n } else if (typeClass.equals(Integer.class)) {\n result = new TimestampIntegerMap();\n } else if (typeClass.equals(Long.class)) {\n result = new TimestampLongMap();\n } else if (typeClass.equals(Float.class)) {\n result = new TimestampFloatMap();\n } else if (typeClass.equals(Double.class)) {\n result = new TimestampDoubleMap();\n } else if (typeClass.equals(Boolean.class)) {\n result = new TimestampBooleanMap();\n } else if (typeClass.equals(Character.class)) {\n result = new TimestampCharMap();\n } else {\n throw new IllegalArgumentException(\"Unsupported type \" + typeClass.getClass().getCanonicalName());\n }\n\n if (input.equalsIgnoreCase(EMPTY_VALUE)) {\n return result;\n }\n\n StringReader reader = new StringReader(input + ' ');// Add 1 space so\n // reader.skip\n // function always\n // works when\n // necessary (end of\n // string not\n // reached).\n\n try {\n int r;\n char c;\n while ((r = reader.read()) != -1) {\n c = (char) r;\n switch (c) {\n case LEFT_BOUND_SQUARE_BRACKET:\n case LEFT_BOUND_BRACKET:\n parseTimestampAndValue(typeClass, reader, result, timeZone);\n break;\n default:\n // Ignore other chars outside of bounds\n }\n }\n } catch (IOException ex) {\n throw new 
RuntimeException(\"Unexpected expection while parsing timestamps\", ex);\n }\n\n return result;\n }", "language": "java"}}] # noqa E501 # pylint: disable=line-too-long scene == codeqa, examples: - [{"role": "user", "content": "I'm writing a small web server in Python, using BaseHTTPServer and a custom subclass of BaseHTTPServer.BaseHTTPRequestHandler. Is it possible to make this listen on more than one port?\nWhat I'm doing now:\nclass MyRequestHandler(BaseHTTPServer.BaseHTTPRequestHandler):\n def doGET\n [...]\n\nclass ThreadingHTTPServer(ThreadingMixIn, HTTPServer): \n pass\n\nserver = ThreadingHTTPServer(('localhost', 80), MyRequestHandler)\nserver.serve_forever()"}] # noqa E501 + [{"role": "user", "content": "I'm writing a small web server in Python, using BaseHTTPServer and a custom subclass of BaseHTTPServer.BaseHTTPRequestHandler. Is it possible to make this listen on more than one port?\nWhat I'm doing now:\nclass MyRequestHandler(BaseHTTPServer.BaseHTTPRequestHandler):\n def doGET\n [...]\n\nclass ThreadingHTTPServer(ThreadingMixIn, HTTPServer): \n pass\n\nserver = ThreadingHTTPServer(('localhost', 80), MyRequestHandler)\nserver.serve_forever()"}] # noqa E501 # pylint: disable=line-too-long scene == nl2sql, examples: - [{"role": "user", "content": "小明的总分数是多少"}, {"role": "attachment", "meta": {"synonym_infos": {"学生姓名": "姓名|名字|名称", "学生分数": "分数|得分"}, "recall_infos": [{"content": "student_score.id='小明'", "score": "0.83"}], "schema_infos": [{"table_id": "student_score", "table_desc": "学生分数表", "columns": [{"col_name": "id", "col_caption": "学生id", "col_desc": "例值为:1,2,3", "col_type": "string"}, {"col_name": "name", "col_caption": "学生姓名", "col_desc": "例值为:张三,李四,小明", "col_type": "string"}, {"col_name": "score", "col_caption": "学生分数", "col_desc": "例值为:98,100,66", "col_type": "string"}]}]}}] # noqa E501 + [{"role": "user", "content": "小明的总分数是多少"}, {"role": "attachment", "meta": {"synonym_infos": {"学生姓名": "姓名|名字|名称", "学生分数": "分数|得分"}, "recall_infos": [{"content": 
"student_score.id='小明'", "score": "0.83"}], "schema_infos": [{"table_id": "student_score", "table_desc": "学生分数表", "columns": [{"col_name": "id", "col_caption": "学生id", "col_desc": "例值为:1,2,3", "col_type": "string"}, {"col_name": "name", "col_caption": "学生姓名", "col_desc": "例值为:张三,李四,小明", "col_type": "string"}, {"col_name": "score", "col_caption": "学生分数", "col_desc": "例值为:98,100,66", "col_type": "string"}]}]}}] # noqa E501 # pylint: disable=line-too-long workspace (str): The dashscope workspace id. **kwargs: - n(int, `optional`): The number of output results, currently only supports 1, with a default value of 1 # noqa E501 + n( + int, + `optional` + ): The number of output results, currently only supports 1, with a default value of 1 # noqa E501 Returns: Union[DashScopeAPIResponse, @@ -117,29 +131,47 @@ def call( stream is True, return Generator, otherwise DashScopeAPIResponse. """ if (scene is None or not scene) or (message is None or not message): - raise InputRequired('scene and message is required!') + raise InputRequired("scene and message is required!") if model is None or not model: - raise ModelRequired('Model is required!') + raise ModelRequired("Model is required!") task_group, task = _get_task_group_and_task(__name__) - input, parameters = cls._build_input_parameters( - model, scene, message, **kwargs) - response = super().call(model=model, - task_group=task_group, - task=task, - function=CodeGeneration.function, - api_key=api_key, - input=input, - workspace=workspace, - **parameters) - - is_stream = kwargs.get('stream', False) + ( + input, # pylint: disable=redefined-builtin + parameters, + ) = cls._build_input_parameters( + model, + scene, + message, + **kwargs, + ) + response = super().call( + model=model, + task_group=task_group, + task=task, + function=CodeGeneration.function, + api_key=api_key, + input=input, + workspace=workspace, + **parameters, + ) + + is_stream = kwargs.get("stream", False) if is_stream: return (rsp for rsp in response) else: 
return response @classmethod - def _build_input_parameters(cls, model, scene, message, **kwargs): - parameters = {'n': kwargs.pop('n', 1)} - input = {SCENE: scene, MESSAGE: message} + def _build_input_parameters( + cls, + model, + scene, + message, + **kwargs, + ): # pylint: disable=unused-argument + parameters = {"n": kwargs.pop("n", 1)} + input = { # pylint: disable=redefined-builtin + SCENE: scene, + MESSAGE: message, + } return input, {**parameters, **kwargs} diff --git a/dashscope/aigc/conversation.py b/dashscope/aigc/conversation.py index 8c26a36..bb7c455 100644 --- a/dashscope/aigc/conversation.py +++ b/dashscope/aigc/conversation.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. import json @@ -5,8 +6,11 @@ from http import HTTPStatus from typing import Any, Dict, Generator, List, Union -from dashscope.api_entities.dashscope_response import (ConversationResponse, - Message, Role) +from dashscope.api_entities.dashscope_response import ( + ConversationResponse, + Message, + Role, +) from dashscope.client.base_api import BaseApi from dashscope.common.constants import DEPRECATED_MESSAGE, HISTORY, PROMPT from dashscope.common.error import InputRequired, InvalidInput, ModelRequired @@ -15,9 +19,8 @@ class HistoryItem(dict): - """A conversation history item. + """A conversation history item.""" - """ def __init__(self, role: str, text: str = None, **kwargs): """Init a history item. @@ -33,7 +36,7 @@ def __init__(self, role: str, text: str = None, **kwargs): self.role = role dict.__init__(self, {role: []}) if text is not None: - self[self.role].append({'text': text}) + self[self.role].append({"text": text}) for k, v in kwargs.items(): self[self.role].append({k: v}) @@ -48,9 +51,8 @@ def add(self, key: str, content: Any): class History(list): - """Manage the conversation history. + """Manage the conversation history.""" - """ def __init__(self, items: List[HistoryItem] = None): """Init a history with list of HistoryItems. 
@@ -66,43 +68,55 @@ def __init__(self, items: List[HistoryItem] = None): def _history_to_qwen_format(history: History, n_history: int): - """Convert history to simple format. - [{"user":"您好", "bot":"我是你的助手,很高兴为您服务"}, - {"user":"user input", "bot":"bot output"}] + """Convert history to Qwen-compatible simple format. + + Transforms conversation history into a simplified format where each + entry contains user and bot messages as key-value pairs. + + Args: + history: The conversation history to convert. + n_history: Number of recent history items to include. + -1 means include all history. + + Returns: + list: Simplified history in format: + [{"user":"您好", "bot":"我是你的助手,很高兴为您服务"}, + {"user":"user input", "bot":"bot output"}] """ simple_history = [] user = None bot = None if n_history != -1 and len(history) >= 2 * n_history: - history = history[len(history) - 2 * n_history:] + history = history[len(history) - 2 * n_history :] for item in history: - if 'user' in item: - user = item['user'][0]['text'] - if 'bot' in item: - bot = item['bot'][0]['text'] + if "user" in item: + user = item["user"][0]["text"] + if "bot" in item: + bot = item["bot"][0]["text"] if user is not None and bot is not None: - simple_history.append({'user': user, 'bot': bot}) + simple_history.append({"user": user, "bot": bot}) user = None bot = None return simple_history class Conversation(BaseApi): - """Conversational robot interface. - """ - task = 'generation' + """Conversational robot interface.""" + + task = "generation" class Models: """@deprecated, use qwen_turbo instead""" - qwen_v1 = 'qwen-v1' + + qwen_v1 = "qwen-v1" """@deprecated, use qwen_plus instead""" - qwen_plus_v1 = 'qwen-plus-v1' + qwen_plus_v1 = "qwen-plus-v1" - qwen_turbo = 'qwen-turbo' - qwen_plus = 'qwen-plus' - qwen_max = 'qwen-max' + qwen_turbo = "qwen-turbo" + qwen_plus = "qwen-plus" + qwen_max = "qwen-max" def __init__(self, history: History = None) -> None: """Init a chat. 
@@ -117,9 +131,10 @@ def __init__(self, history: History = None) -> None: self.history = History() else: logger.warning(DEPRECATED_MESSAGE) - self.history = history + self.history = history # type: ignore[has-type] - def call( + # pylint: disable=arguments-renamed + def call( # type: ignore[override] self, model: str, prompt: Any = None, @@ -130,11 +145,25 @@ def call( messages: List[Message] = None, plugins: Union[str, Dict[str, Any]] = None, workspace: str = None, - **kwargs - ) -> Union[ConversationResponse, Generator[ConversationResponse, None, - None]]: + **kwargs, + ) -> Union[ + ConversationResponse, + Generator[ + ConversationResponse, + None, + None, + ], + ]: """Call conversational robot generator a response. + Note: This method overrides BaseApi.call() as an instance method + instead of a classmethod because Conversation maintains instance + state (self.history). Pylint's arguments-renamed warning is + disabled because the first parameter changes from 'cls' (in the + classmethod) to 'self' (in the instance method). The type + checker is instructed to ignore the signature incompatibility + via type: ignore[override]. + Args: model (str): The request model. prompt (Any): The input prompt. @@ -157,7 +186,7 @@ def call( plugins (Any): The plugin config, Can be plugins config str, or dict. **kwargs(qwen-turbo, qwen-plus): stream(bool, `optional`): Enable server-sent events - (ref: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events) # noqa E501 + (ref: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events) # noqa E501 # pylint: disable=line-too-long the result will back partially. temperature(float, `optional`): Used to control the degree of randomness and diversity. Specifically, the temperature @@ -175,8 +204,8 @@ def call( tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. 
- top_k(int, `optional`): The size of the sample candidate set when generated. # noqa E501 - For example, when the value is 50, only the 50 highest-scoring tokens # noqa E501 + top_k(int, `optional`): The size of the sample candidate set when generated. # noqa E501 # pylint: disable=line-too-long + For example, when the value is 50, only the 50 highest-scoring tokens # noqa E501 # pylint: disable=line-too-long in a single generation form a randomly sampled candidate set. # noqa E501 The larger the value, the higher the randomness generated; # noqa E501 the smaller the value, the higher the certainty generated. # noqa E501 @@ -191,20 +220,20 @@ def call( large model product, support model: [bailian-v1]. result_format(str, `optional`): [message|text] Set result result format. # noqa E501 Default result is text - incremental_output(bool, `optional`): Used to control the streaming output mode. # noqa E501 - If true, the subsequent output will include the previously input content. # noqa E501 - Otherwise, the subsequent output will not include the previously output # noqa E501 + incremental_output(bool, `optional`): Used to control the streaming output mode. # noqa E501 # pylint: disable=line-too-long + If true, the subsequent output will include the previously input content. # noqa E501 # pylint: disable=line-too-long + Otherwise, the subsequent output will not include the previously output # noqa E501 # pylint: disable=line-too-long content. Default false. - stop(list[str] or list[list[int]], `optional`): Used to control the generation to stop # noqa E501 - when encountering setting str or token ids, the result will not include # noqa E501 + stop(list[str] or list[list[int]], `optional`): Used to control the generation to stop # noqa E501 # pylint: disable=line-too-long + when encountering setting str or token ids, the result will not include # noqa E501 # pylint: disable=line-too-long stop words or tokens. 
- max_tokens(int, `optional`): The maximum token num expected to be output. It should be # noqa E501 - noted that the length generated by the model will only be less than max_tokens, # noqa E501 - not necessarily equal to it. If max_tokens is set too large, the service will # noqa E501 + max_tokens(int, `optional`): The maximum token num expected to be output. It should be # noqa E501 # pylint: disable=line-too-long + noted that the length generated by the model will only be less than max_tokens, # noqa E501 # pylint: disable=line-too-long + not necessarily equal to it. If max_tokens is set too large, the service will # noqa E501 # pylint: disable=line-too-long directly prompt that the length exceeds the limit. It is generally # noqa E501 not recommended to set this value. - repetition_penalty(float, `optional`): Used to control the repeatability when generating models. # noqa E501 - Increasing repetition_penalty can reduce the duplication of model generation. # noqa E501 + repetition_penalty(float, `optional`): Used to control the repeatability when generating models. # noqa E501 # pylint: disable=line-too-long + Increasing repetition_penalty can reduce the duplication of model generation. # noqa E501 # pylint: disable=line-too-long 1.0 means no punishment. workspace (str): The dashscope workspace id. Raises: @@ -217,98 +246,165 @@ def call( stream is True, return Generator, otherwise ConversationResponse. 
""" - if ((prompt is None or not prompt) - and ((messages is None or not messages))): - raise InputRequired('prompt or messages is required!') + if (prompt is None or not prompt) and ( + (messages is None or not messages) + ): + raise InputRequired("prompt or messages is required!") if model is None or not model: - raise ModelRequired('Model is required!') + raise ModelRequired("Model is required!") task_group, _ = _get_task_group_and_task(__name__) if plugins is not None: - headers = kwargs.pop('headers', {}) + headers = kwargs.pop("headers", {}) if isinstance(plugins, str): - headers['X-DashScope-Plugin'] = plugins + headers["X-DashScope-Plugin"] = plugins else: - headers['X-DashScope-Plugin'] = json.dumps(plugins) - kwargs['headers'] = headers - input, parameters = self._build_input_parameters( - model, prompt, history, auto_history, n_history, messages, - **kwargs) - response = super().call(model=model, - task_group=task_group, - task='text-generation', - function='generation', - api_key=api_key, - input=input, - workspace=workspace, - **parameters) - is_stream = kwargs.get('stream', False) + headers["X-DashScope-Plugin"] = json.dumps(plugins) + kwargs["headers"] = headers + ( + input, # pylint: disable=redefined-builtin + parameters, + ) = self._build_input_parameters( + model, + prompt, + history, + auto_history, + n_history, + messages, + **kwargs, + ) + response = super().call( + model=model, + task_group=task_group, + task="text-generation", + function="generation", + api_key=api_key, + input=input, + workspace=workspace, + **parameters, + ) + is_stream = kwargs.get("stream", False) return self._handle_response(prompt, response, is_stream) def _handle_stream_response(self, prompt, responses): + """Handle streaming response and update conversation history. + + Args: + prompt: The user's input prompt. + responses: Generator yielding response chunks. + + Yields: + ConversationResponse: Parsed response objects. 
+ """ for rsp in responses: rsp = ConversationResponse.from_api_response(rsp) yield rsp - if rsp.status_code == HTTPStatus.OK and rsp.output.choices is None: - user_item = HistoryItem('user', text=prompt) - bot_history_item = HistoryItem('bot', text=rsp.output.text) + if ( + # pylint: disable=undefined-loop-variable + rsp.status_code == HTTPStatus.OK + and rsp.output.choices # pylint: disable=undefined-loop-variable + is None # pylint: disable=undefined-loop-variable + ): # pylint: disable=undefined-loop-variable + user_item = HistoryItem("user", text=prompt) + bot_history_item = HistoryItem("bot", text=rsp.output.text) self.history.append(user_item) self.history.append(bot_history_item) def _handle_response(self, prompt, response, is_stream): + """Handle API response and update conversation history. + + Args: + prompt: The user's input prompt. + response: The API response or response generator. + is_stream: Whether the response is streaming. + + Returns: + ConversationResponse or Generator: Parsed response. 
+ """ if is_stream: - return (rsp - for rsp in self._handle_stream_response(prompt, response)) + return ( + rsp for rsp in self._handle_stream_response(prompt, response) + ) else: response = ConversationResponse.from_api_response(response) - if (response.status_code == HTTPStatus.OK - and response.output.choices is None): - user_item = HistoryItem('user', text=prompt) - bot_history_item = HistoryItem('bot', - text=response.output['text']) + if ( + response.status_code == HTTPStatus.OK + and response.output.choices is None + ): + user_item = HistoryItem("user", text=prompt) + bot_history_item = HistoryItem( + "bot", + text=response.output["text"], + ) self.history.append(user_item) self.history.append(bot_history_item) return response - def _build_input_parameters(self, model, prompt, history, auto_history, - n_history, messages, **kwargs): + def _build_input_parameters( + self, + model, + prompt, + history, + auto_history, + n_history, + messages, + **kwargs, + ): + """Build input data and parameters for API call. + + Args: + model: The model name. + prompt: The user's input prompt. + history: User-provided conversation history. + auto_history: Whether to use automatic history management. + n_history: Number of history items to include. + messages: List of message objects. + **kwargs: Additional parameters. + + Returns: + tuple: (input, parameters) for API call. + """ if model == Conversation.Models.qwen_v1: logger.warning( - 'Model %s is deprecated, use %s instead!' % - (Conversation.Models.qwen_v1, Conversation.Models.qwen_turbo)) + "Model %s is deprecated, use %s instead!", + Conversation.Models.qwen_v1, + Conversation.Models.qwen_turbo, + ) if model == Conversation.Models.qwen_plus_v1: - logger.warning('Model %s is deprecated, use %s instead!' 
% - (Conversation.Models.qwen_plus_v1, - Conversation.Models.qwen_plus)) + logger.warning( + "Model %s is deprecated, use %s instead!", + Conversation.Models.qwen_plus_v1, + Conversation.Models.qwen_plus, + ) parameters = {} if history is not None and auto_history: - raise InvalidInput('auto_history is True, history must None') + raise InvalidInput("auto_history is True, history must None") if history is not None: # use user provided history or system. logger.warning(DEPRECATED_MESSAGE) - input = { - PROMPT: - prompt, - HISTORY: - _history_to_qwen_format(history, n_history) if history else [], + input = { # pylint: disable=redefined-builtin + PROMPT: prompt, + HISTORY: _history_to_qwen_format(history, n_history) + if history + else [], } elif auto_history: logger.warning(DEPRECATED_MESSAGE) input = { PROMPT: prompt, - HISTORY: _history_to_qwen_format(self.history, n_history) + HISTORY: _history_to_qwen_format(self.history, n_history), } elif messages: msgs = deepcopy(messages) if prompt is not None and prompt: - msgs.append({'role': Role.USER, 'content': prompt}) - input = {'messages': msgs} + msgs.append({"role": Role.USER, "content": prompt}) + input = {"messages": msgs} else: input = { PROMPT: prompt, } # parameters - if model.startswith('qwen'): - enable_search = kwargs.pop('enable_search', False) + if model.startswith("qwen"): + enable_search = kwargs.pop("enable_search", False) if enable_search: - parameters['enable_search'] = enable_search + parameters["enable_search"] = enable_search return input, {**parameters, **kwargs} diff --git a/dashscope/aigc/generation.py b/dashscope/aigc/generation.py index 54f1792..5501c83 100644 --- a/dashscope/aigc/generation.py +++ b/dashscope/aigc/generation.py @@ -1,15 +1,23 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
import copy import json from typing import Any, Dict, Generator, List, Union, AsyncGenerator -from dashscope.api_entities.dashscope_response import (GenerationResponse, - Message, Role) +from dashscope.api_entities.dashscope_response import ( + GenerationResponse, + Message, + Role, +) from dashscope.client.base_api import BaseAioApi, BaseApi -from dashscope.common.constants import (CUSTOMIZED_MODEL_ID, - DEPRECATED_MESSAGE, HISTORY, MESSAGES, - PROMPT) +from dashscope.common.constants import ( + CUSTOMIZED_MODEL_ID, + DEPRECATED_MESSAGE, + HISTORY, + MESSAGES, + PROMPT, +) from dashscope.common.error import InputRequired, ModelRequired from dashscope.common.logging import logger from dashscope.common.utils import _get_task_group_and_task @@ -18,24 +26,27 @@ class Generation(BaseApi): - task = 'text-generation' + task = "text-generation" """API for AI-Generated Content(AIGC) models. """ + class Models: """@deprecated, use qwen_turbo instead""" - qwen_v1 = 'qwen-v1' + + qwen_v1 = "qwen-v1" """@deprecated, use qwen_plus instead""" - qwen_plus_v1 = 'qwen-plus-v1' + qwen_plus_v1 = "qwen-plus-v1" - bailian_v1 = 'bailian-v1' - dolly_12b_v2 = 'dolly-12b-v2' - qwen_turbo = 'qwen-turbo' - qwen_plus = 'qwen-plus' - qwen_max = 'qwen-max' + bailian_v1 = "bailian-v1" + dolly_12b_v2 = "dolly-12b-v2" + qwen_turbo = "qwen-turbo" + qwen_plus = "qwen-plus" + qwen_max = "qwen-max" @classmethod - def call( + # type: ignore[override] + def call( # pylint: disable=arguments-renamed # type: ignore[override] cls, model: str, prompt: Any = None, @@ -44,7 +55,7 @@ def call( messages: List[Message] = None, plugins: Union[str, Dict[str, Any]] = None, workspace: str = None, - **kwargs + **kwargs, ) -> Union[GenerationResponse, Generator[GenerationResponse, None, None]]: """Call generation model service. @@ -66,7 +77,7 @@ def call( plugins (Any): The plugin config. Can be plugins config str, or dict. 
**kwargs: stream(bool, `optional`): Enable server-sent events - (ref: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events) # noqa E501 + (ref: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events) # noqa E501 # pylint: disable=line-too-long the result will back partially[qwen-turbo,bailian-v1]. temperature(float, `optional`): Used to control the degree of randomness and diversity. Specifically, the temperature @@ -84,8 +95,8 @@ def call( tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered[qwen-turbo,bailian-v1]. - top_k(int, `optional`): The size of the sample candidate set when generated. # noqa E501 - For example, when the value is 50, only the 50 highest-scoring tokens # noqa E501 + top_k(int, `optional`): The size of the sample candidate set when generated. # noqa E501 # pylint: disable=line-too-long + For example, when the value is 50, only the 50 highest-scoring tokens # noqa E501 # pylint: disable=line-too-long in a single generation form a randomly sampled candidate set. # noqa E501 The larger the value, the higher the randomness generated; # noqa E501 the smaller the value, the higher the certainty generated. # noqa E501 @@ -100,20 +111,20 @@ def call( large model product, support model: [bailian-v1]. result_format(str, `optional`): [message|text] Set result result format. # noqa E501 Default result is text - incremental_output(bool, `optional`): Used to control the streaming output mode. # noqa E501 - If true, the subsequent output will include the previously input content. # noqa E501 - Otherwise, the subsequent output will not include the previously output # noqa E501 + incremental_output(bool, `optional`): Used to control the streaming output mode. # noqa E501 # pylint: disable=line-too-long + If true, the subsequent output will include the previously input content. 
# noqa E501 # pylint: disable=line-too-long + Otherwise, the subsequent output will not include the previously output # noqa E501 # pylint: disable=line-too-long content. Default false. - stop(list[str] or list[list[int]], `optional`): Used to control the generation to stop # noqa E501 - when encountering setting str or token ids, the result will not include # noqa E501 + stop(list[str] or list[list[int]], `optional`): Used to control the generation to stop # noqa E501 # pylint: disable=line-too-long + when encountering setting str or token ids, the result will not include # noqa E501 # pylint: disable=line-too-long stop words or tokens. - max_tokens(int, `optional`): The maximum token num expected to be output. It should be # noqa E501 - noted that the length generated by the model will only be less than max_tokens, # noqa E501 - not necessarily equal to it. If max_tokens is set too large, the service will # noqa E501 + max_tokens(int, `optional`): The maximum token num expected to be output. It should be # noqa E501 # pylint: disable=line-too-long + noted that the length generated by the model will only be less than max_tokens, # noqa E501 # pylint: disable=line-too-long + not necessarily equal to it. If max_tokens is set too large, the service will # noqa E501 # pylint: disable=line-too-long directly prompt that the length exceeds the limit. It is generally # noqa E501 not recommended to set this value. - repetition_penalty(float, `optional`): Used to control the repeatability when generating models. # noqa E501 - Increasing repetition_penalty can reduce the duplication of model generation. # noqa E501 + repetition_penalty(float, `optional`): Used to control the repeatability when generating models. # noqa E501 # pylint: disable=line-too-long + Increasing repetition_penalty can reduce the duplication of model generation. # noqa E501 # pylint: disable=line-too-long 1.0 means no punishment. workspace (str): The dashscope workspace id. 
Raises: @@ -124,69 +135,95 @@ def call( Generator[GenerationResponse, None, None]]: If stream is True, return Generator, otherwise GenerationResponse. """ - if (prompt is None or not prompt) and (messages is None - or not messages): - raise InputRequired('prompt or messages is required!') + if (prompt is None or not prompt) and ( + messages is None or not messages + ): + raise InputRequired("prompt or messages is required!") if model is None or not model: - raise ModelRequired('Model is required!') + raise ModelRequired("Model is required!") task_group, function = _get_task_group_and_task(__name__) if plugins is not None: - headers = kwargs.pop('headers', {}) + headers = kwargs.pop("headers", {}) if isinstance(plugins, str): - headers['X-DashScope-Plugin'] = plugins + headers["X-DashScope-Plugin"] = plugins else: - headers['X-DashScope-Plugin'] = json.dumps(plugins) - kwargs['headers'] = headers - input, parameters = cls._build_input_parameters( - model, prompt, history, messages, **kwargs) + headers["X-DashScope-Plugin"] = json.dumps(plugins) + kwargs["headers"] = headers + ( + input, # pylint: disable=redefined-builtin + parameters, + ) = cls._build_input_parameters( + model, + prompt, + history, + messages, + **kwargs, + ) - is_stream = parameters.get('stream', False) + is_stream = parameters.get("stream", False) # Check if we need to merge incremental output - is_incremental_output = kwargs.get('incremental_output', None) + is_incremental_output = kwargs.get("incremental_output", None) to_merge_incremental_output = False - if (ParamUtil.should_modify_incremental_output(model) and - is_stream and is_incremental_output is False): + if ( + ParamUtil.should_modify_incremental_output(model) + and is_stream + and is_incremental_output is False + ): to_merge_incremental_output = True - parameters['incremental_output'] = True + parameters["incremental_output"] = True # Pass incremental_to_full flag via headers user-agent - if 'headers' not in parameters: - 
parameters['headers'] = {} - flag = '1' if to_merge_incremental_output else '0' - parameters['headers']['user-agent'] = f'incremental_to_full/{flag}' - - response = super().call(model=model, - task_group=task_group, - task=Generation.task, - function=function, - api_key=api_key, - input=input, - workspace=workspace, - **parameters) + if "headers" not in parameters: + parameters["headers"] = {} + flag = "1" if to_merge_incremental_output else "0" + parameters["headers"]["user-agent"] = f"incremental_to_full/{flag}" + + response = super().call( + model=model, + task_group=task_group, + task=Generation.task, + function=function, + api_key=api_key, + input=input, + workspace=workspace, + **parameters, + ) if is_stream: if to_merge_incremental_output: # Extract n parameter for merge logic - n = parameters.get('n', 1) + n = parameters.get("n", 1) return cls._merge_generation_response(response, n) else: - return (GenerationResponse.from_api_response(rsp) - for rsp in response) + return ( + GenerationResponse.from_api_response(rsp) + for rsp in response + ) else: return GenerationResponse.from_api_response(response) @classmethod - def _build_input_parameters(cls, model, prompt, history, messages, - **kwargs): + def _build_input_parameters( + cls, + model, + prompt, + history, + messages, + **kwargs, + ): if model == Generation.Models.qwen_v1: logger.warning( - 'Model %s is deprecated, use %s instead!' % - (Generation.Models.qwen_v1, Generation.Models.qwen_turbo)) + "Model %s is deprecated, use %s instead!", + Generation.Models.qwen_v1, + Generation.Models.qwen_turbo, + ) if model == Generation.Models.qwen_plus_v1: logger.warning( - 'Model %s is deprecated, use %s instead!' 
% - (Generation.Models.qwen_plus_v1, Generation.Models.qwen_plus)) + "Model %s is deprecated, use %s instead!", + Generation.Models.qwen_plus_v1, + Generation.Models.qwen_plus, + ) parameters = {} - input = {} + input = {} # pylint: disable=redefined-builtin if history is not None: logger.warning(DEPRECATED_MESSAGE) input[HISTORY] = history @@ -195,31 +232,40 @@ def _build_input_parameters(cls, model, prompt, history, messages, elif messages is not None: msgs = copy.deepcopy(messages) if prompt is not None and prompt: - msgs.append({'role': Role.USER, 'content': prompt}) + msgs.append({"role": Role.USER, "content": prompt}) input = {MESSAGES: msgs} else: input[PROMPT] = prompt - if model.startswith('qwen'): - enable_search = kwargs.pop('enable_search', False) + if model.startswith("qwen"): + enable_search = kwargs.pop("enable_search", False) if enable_search: - parameters['enable_search'] = enable_search - elif model.startswith('bailian'): - customized_model_id = kwargs.pop('customized_model_id', None) + parameters["enable_search"] = enable_search + elif model.startswith("bailian"): + customized_model_id = kwargs.pop("customized_model_id", None) if customized_model_id is None: - raise InputRequired('customized_model_id is required for %s' % - model) + raise InputRequired( + f"customized_model_id is required for {model}", + ) input[CUSTOMIZED_MODEL_ID] = customized_model_id return input, {**parameters, **kwargs} @classmethod - def _merge_generation_response(cls, response, n=1) -> Generator[GenerationResponse, None, None]: - """Merge incremental response chunks to simulate non-incremental output.""" + def _merge_generation_response( + cls, + response, + n=1, + ) -> Generator[GenerationResponse, None, None]: + """Merge incremental response chunks to simulate non-incremental output.""" # noqa: E501 accumulated_data = {} for rsp in response: parsed_response = GenerationResponse.from_api_response(rsp) - result = merge_single_response(parsed_response, accumulated_data, n) 
+ result = merge_single_response( + parsed_response, + accumulated_data, + n, + ) if result is True: yield parsed_response elif isinstance(result, list): @@ -229,24 +275,28 @@ def _merge_generation_response(cls, response, n=1) -> Generator[GenerationRespon class AioGeneration(BaseAioApi): - task = 'text-generation' + task = "text-generation" """API for AI-Generated Content(AIGC) models. """ + class Models: """@deprecated, use qwen_turbo instead""" - qwen_v1 = 'qwen-v1' + + qwen_v1 = "qwen-v1" """@deprecated, use qwen_plus instead""" - qwen_plus_v1 = 'qwen-plus-v1' + qwen_plus_v1 = "qwen-plus-v1" - bailian_v1 = 'bailian-v1' - dolly_12b_v2 = 'dolly-12b-v2' - qwen_turbo = 'qwen-turbo' - qwen_plus = 'qwen-plus' - qwen_max = 'qwen-max' + bailian_v1 = "bailian-v1" + dolly_12b_v2 = "dolly-12b-v2" + qwen_turbo = "qwen-turbo" + qwen_plus = "qwen-plus" + qwen_max = "qwen-max" + # type: ignore[override] @classmethod - async def call( + async def call( # type: ignore[override] # pylint: disable=arguments-renamed # noqa: E501 + # type: ignore[override] cls, model: str, prompt: Any = None, @@ -255,7 +305,7 @@ async def call( messages: List[Message] = None, plugins: Union[str, Dict[str, Any]] = None, workspace: str = None, - **kwargs + **kwargs, ) -> Union[GenerationResponse, AsyncGenerator[GenerationResponse, None]]: """Call generation model service. @@ -277,7 +327,7 @@ async def call( plugins (Any): The plugin config. Can be plugins config str, or dict. **kwargs: stream(bool, `optional`): Enable server-sent events - (ref: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events) # noqa E501 + (ref: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events) # noqa E501 # pylint: disable=line-too-long the result will back partially[qwen-turbo,bailian-v1]. temperature(float, `optional`): Used to control the degree of randomness and diversity. 
Specifically, the temperature @@ -295,8 +345,8 @@ async def call( tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered[qwen-turbo,bailian-v1]. - top_k(int, `optional`): The size of the sample candidate set when generated. # noqa E501 - For example, when the value is 50, only the 50 highest-scoring tokens # noqa E501 + top_k(int, `optional`): The size of the sample candidate set when generated. # noqa E501 # pylint: disable=line-too-long + For example, when the value is 50, only the 50 highest-scoring tokens # noqa E501 # pylint: disable=line-too-long in a single generation form a randomly sampled candidate set. # noqa E501 The larger the value, the higher the randomness generated; # noqa E501 the smaller the value, the higher the certainty generated. # noqa E501 @@ -311,20 +361,20 @@ async def call( large model product, support model: [bailian-v1]. result_format(str, `optional`): [message|text] Set result result format. # noqa E501 Default result is text - incremental_output(bool, `optional`): Used to control the streaming output mode. # noqa E501 - If true, the subsequent output will include the previously input content. # noqa E501 - Otherwise, the subsequent output will not include the previously output # noqa E501 + incremental_output(bool, `optional`): Used to control the streaming output mode. # noqa E501 # pylint: disable=line-too-long + If true, the subsequent output will include the previously input content. # noqa E501 # pylint: disable=line-too-long + Otherwise, the subsequent output will not include the previously output # noqa E501 # pylint: disable=line-too-long content. Default false. 
- stop(list[str] or list[list[int]], `optional`): Used to control the generation to stop # noqa E501 - when encountering setting str or token ids, the result will not include # noqa E501 + stop(list[str] or list[list[int]], `optional`): Used to control the generation to stop # noqa E501 # pylint: disable=line-too-long + when encountering setting str or token ids, the result will not include # noqa E501 # pylint: disable=line-too-long stop words or tokens. - max_tokens(int, `optional`): The maximum token num expected to be output. It should be # noqa E501 - noted that the length generated by the model will only be less than max_tokens, # noqa E501 - not necessarily equal to it. If max_tokens is set too large, the service will # noqa E501 + max_tokens(int, `optional`): The maximum token num expected to be output. It should be # noqa E501 # pylint: disable=line-too-long + noted that the length generated by the model will only be less than max_tokens, # noqa E501 # pylint: disable=line-too-long + not necessarily equal to it. If max_tokens is set too large, the service will # noqa E501 # pylint: disable=line-too-long directly prompt that the length exceeds the limit. It is generally # noqa E501 not recommended to set this value. - repetition_penalty(float, `optional`): Used to control the repeatability when generating models. # noqa E501 - Increasing repetition_penalty can reduce the duplication of model generation. # noqa E501 + repetition_penalty(float, `optional`): Used to control the repeatability when generating models. # noqa E501 # pylint: disable=line-too-long + Increasing repetition_penalty can reduce the duplication of model generation. # noqa E501 # pylint: disable=line-too-long 1.0 means no punishment. workspace (str): The dashscope workspace id. Raises: @@ -335,49 +385,64 @@ async def call( AsyncGenerator[GenerationResponse, None]]: If stream is True, return AsyncGenerator, otherwise GenerationResponse. 
""" - if (prompt is None or not prompt) and (messages is None - or not messages): - raise InputRequired('prompt or messages is required!') + if (prompt is None or not prompt) and ( + messages is None or not messages + ): + raise InputRequired("prompt or messages is required!") if model is None or not model: - raise ModelRequired('Model is required!') + raise ModelRequired("Model is required!") task_group, function = _get_task_group_and_task(__name__) if plugins is not None: - headers = kwargs.pop('headers', {}) + headers = kwargs.pop("headers", {}) if isinstance(plugins, str): - headers['X-DashScope-Plugin'] = plugins + headers["X-DashScope-Plugin"] = plugins else: - headers['X-DashScope-Plugin'] = json.dumps(plugins) - kwargs['headers'] = headers - input, parameters = Generation._build_input_parameters( - model, prompt, history, messages, **kwargs) + headers["X-DashScope-Plugin"] = json.dumps(plugins) + kwargs["headers"] = headers + # pylint: disable=protected-access + ( + input, # pylint: disable=redefined-builtin + parameters, + ) = Generation._build_input_parameters( + model, + prompt, + history, + messages, + **kwargs, + ) - is_stream = parameters.get('stream', False) + is_stream = parameters.get("stream", False) # Check if we need to merge incremental output - is_incremental_output = kwargs.get('incremental_output', None) + is_incremental_output = kwargs.get("incremental_output", None) to_merge_incremental_output = False - if (ParamUtil.should_modify_incremental_output(model) and - is_stream and is_incremental_output is False): + if ( + ParamUtil.should_modify_incremental_output(model) + and is_stream + and is_incremental_output is False + ): to_merge_incremental_output = True - parameters['incremental_output'] = True + parameters["incremental_output"] = True # Pass incremental_to_full flag via headers user-agent - if 'headers' not in parameters: - parameters['headers'] = {} - flag = '1' if to_merge_incremental_output else '0' - 
parameters['headers']['user-agent'] = f'incremental_to_full/{flag}' - - response = await super().call(model=model, - task_group=task_group, - task=Generation.task, - function=function, - api_key=api_key, - input=input, - workspace=workspace, - **parameters) + if "headers" not in parameters: + parameters["headers"] = {} + flag = "1" if to_merge_incremental_output else "0" + parameters["headers"]["user-agent"] = f"incremental_to_full/{flag}" + + response = await super().call( + model=model, + task_group=task_group, + task=Generation.task, + function=function, + api_key=api_key, + input=input, + workspace=workspace, + **parameters, + ) if is_stream: if to_merge_incremental_output: # Extract n parameter for merge logic - n = parameters.get('n', 1) + n = parameters.get("n", 1) return cls._merge_generation_response(response, n) else: return cls._stream_responses(response) @@ -385,20 +450,31 @@ async def call( return GenerationResponse.from_api_response(response) @classmethod - async def _stream_responses(cls, response) -> AsyncGenerator[GenerationResponse, None]: + async def _stream_responses( + cls, + response, + ) -> AsyncGenerator[GenerationResponse, None]: """Convert async response stream to GenerationResponse stream.""" # Type hint: when stream=True, response is actually an AsyncIterable async for rsp in response: # type: ignore yield GenerationResponse.from_api_response(rsp) @classmethod - async def _merge_generation_response(cls, response, n=1) -> AsyncGenerator[GenerationResponse, None]: + async def _merge_generation_response( + cls, + response, + n=1, + ) -> AsyncGenerator[GenerationResponse, None]: """Async version of merge incremental response chunks.""" accumulated_data = {} async for rsp in response: # type: ignore parsed_response = GenerationResponse.from_api_response(rsp) - result = merge_single_response(parsed_response, accumulated_data, n) + result = merge_single_response( + parsed_response, + accumulated_data, + n, + ) if result is True: yield 
parsed_response elif isinstance(result, list): diff --git a/dashscope/aigc/image_generation.py b/dashscope/aigc/image_generation.py index c4952c9..7c23729 100644 --- a/dashscope/aigc/image_generation.py +++ b/dashscope/aigc/image_generation.py @@ -1,10 +1,19 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. from typing import Generator, List, Union, AsyncGenerator -from dashscope.api_entities.dashscope_response import (GenerationResponse, - Message, DashScopeAPIResponse, ImageGenerationResponse) -from dashscope.client.base_api import BaseAioApi, BaseApi, BaseAsyncApi, BaseAsyncAioApi +from dashscope.api_entities.dashscope_response import ( + Message, + DashScopeAPIResponse, + ImageGenerationResponse, +) +from dashscope.client.base_api import ( + BaseAioApi, + BaseApi, + BaseAsyncApi, + BaseAsyncAioApi, +) from dashscope.common.error import InputRequired, ModelRequired from dashscope.common.utils import _get_task_group_and_task from dashscope.utils.oss_utils import preprocess_message_element @@ -13,26 +22,29 @@ class ImageGeneration(BaseApi, BaseAsyncApi): - sync_task = 'multimodal-generation' - async_task = 'image-generation' - function = 'generation' + sync_task = "multimodal-generation" + async_task = "image-generation" + function = "generation" """API for AI-Generated Content(AIGC) models. """ - class Models: - wan2_6_image = 'wan2.6-image' - wan2_6_t2i = 'wan2.6-t2i' + class Models: + wan2_6_image = "wan2.6-image" + wan2_6_t2i = "wan2.6-t2i" @classmethod - def call( + def call( # type: ignore[override] cls, model: str, api_key: str = None, messages: List[Message] = None, workspace: str = None, - **kwargs - ) -> Union[ImageGenerationResponse, Generator[ImageGenerationResponse, None, None]]: + **kwargs, + ) -> Union[ + ImageGenerationResponse, + Generator[ImageGenerationResponse, None, None], + ]: """Call generation model service. 
Args: @@ -46,7 +58,7 @@ def call( {'role': 'assistant', 'content': 'Suitable for outings'}] **kwargs: stream(bool, `optional`): Enable server-sent events - (ref: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events) # noqa E501 + (ref: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events) # noqa E501 # pylint: disable=line-too-long the result will back partially[qwen-turbo,bailian-v1]. workspace (str): The dashscope workspace id. Raises: @@ -58,80 +70,93 @@ def call( stream is True, return Generator, otherwise ImageGenerationResponse. """ if messages is None or not messages: - raise InputRequired('messages is required!') + raise InputRequired("messages is required!") if model is None or not model: - raise ModelRequired('Model is required!') + raise ModelRequired("Model is required!") task_group, _ = _get_task_group_and_task(__name__) _input = {} if messages is not None and messages: - has_upload = cls._preprocess_messages(model, messages, api_key) + has_upload = cls._preprocess_messages(model, messages, api_key) # type: ignore[arg-type] # pylint: disable=line-too-long # noqa: E501 if has_upload: - headers = kwargs.pop('headers', {}) - headers['X-DashScope-OssResourceResolve'] = 'enable' - kwargs['headers'] = headers + headers = kwargs.pop("headers", {}) + headers["X-DashScope-OssResourceResolve"] = "enable" + kwargs["headers"] = headers - _input.update({'messages': messages}) + _input.update({"messages": messages}) # Check if we need to merge incremental output - is_incremental_output = kwargs.get('incremental_output', None) - is_stream = kwargs.get('stream', False) + is_incremental_output = kwargs.get("incremental_output", None) + is_stream = kwargs.get("stream", False) to_merge_incremental_output = False - if (ParamUtil.should_modify_incremental_output(model) and - is_stream and is_incremental_output is not None and is_incremental_output is False): + if ( + 
ParamUtil.should_modify_incremental_output(model) + and is_stream + and is_incremental_output is not None + and is_incremental_output is False + ): to_merge_incremental_output = True - kwargs['incremental_output'] = True + kwargs["incremental_output"] = True # Pass incremental_to_full flag via headers user-agent - if 'headers' not in kwargs: - kwargs['headers'] = {} + if "headers" not in kwargs: + kwargs["headers"] = {} - flag = '1' if to_merge_incremental_output else '0' - kwargs['headers']['user-agent'] = f'incremental_to_full/{flag}' - if kwargs.get('is_async', False): - kwargs['headers']['X-DashScope-Async'] = 'enable' + flag = "1" if to_merge_incremental_output else "0" + kwargs["headers"]["user-agent"] = f"incremental_to_full/{flag}" + if kwargs.get("is_async", False): + kwargs["headers"]["X-DashScope-Async"] = "enable" task = cls.async_task else: task = cls.sync_task - response = super().call(model=model, - task_group=task_group, - task=task, - function=ImageGeneration.function, - api_key=api_key, - input=_input, - workspace=workspace, - **kwargs) + response = super().call( + model=model, + task_group=task_group, + task=task, + function=ImageGeneration.function, + api_key=api_key, + input=_input, + workspace=workspace, + **kwargs, + ) if is_stream: if to_merge_incremental_output: # Extract n parameter for merge logic - n = kwargs.get('n', 1) + n = kwargs.get("n", 1) return cls._merge_generation_response(response, n) else: - return (ImageGenerationResponse.from_api_response(rsp) - for rsp in response) + return ( + ImageGenerationResponse.from_api_response(rsp) + for rsp in response + ) else: return ImageGenerationResponse.from_api_response(response) @classmethod - def async_call( + def async_call( # type: ignore[override] cls, model: str, api_key: str = None, messages: List[Message] = None, workspace: str = None, - **kwargs - ) -> Union[ImageGenerationResponse, Generator[ImageGenerationResponse, None, None]]: + **kwargs, + ) -> Union[ + 
ImageGenerationResponse, + Generator[ImageGenerationResponse, None, None], + ]: kwargs["is_async"] = True return cls.call(model, api_key, messages, workspace, **kwargs) @classmethod - def fetch(cls, - task: Union[str, ImageGenerationResponse], - api_key: str = None, - workspace: str = None, - **kwargs) -> DashScopeAPIResponse: + def fetch( + cls, + task: Union[str, ImageGenerationResponse], # type: ignore[override] + api_key: str = None, + workspace: str = None, + **kwargs, + ) -> DashScopeAPIResponse: """Fetch image(s) synthesis task status or result. Args: @@ -147,11 +172,13 @@ def fetch(cls, return ImageGenerationResponse.from_api_response(response) @classmethod - def wait(cls, - task: Union[str, ImageGenerationResponse], - api_key: str = None, - workspace: str = None, - **kwargs) -> DashScopeAPIResponse: + def wait( + cls, + task: Union[str, ImageGenerationResponse], # type: ignore[override] + api_key: str = None, + workspace: str = None, + **kwargs, + ) -> DashScopeAPIResponse: """Wait for image(s) synthesis task to complete, and return the result. Args: @@ -167,11 +194,13 @@ def wait(cls, return ImageGenerationResponse.from_api_response(response) @classmethod - def cancel(cls, - task: Union[str, ImageGenerationResponse], - api_key: str = None, - workspace: str = None, - **kwargs) -> DashScopeAPIResponse: + def cancel( + cls, + task: Union[str, ImageGenerationResponse], # type: ignore[override] + api_key: str = None, + workspace: str = None, + **kwargs, + ) -> DashScopeAPIResponse: """Cancel image synthesis task. Only tasks whose status is PENDING can be canceled. 
@@ -187,18 +216,20 @@ def cancel(cls, return super().cancel(task, api_key, workspace=workspace) @classmethod - def list(cls, - start_time: str = None, - end_time: str = None, - model_name: str = None, - api_key_id: str = None, - region: str = None, - status: str = None, - page_no: int = 1, - page_size: int = 10, - api_key: str = None, - workspace: str = None, - **kwargs) -> DashScopeAPIResponse: + def list( + cls, + start_time: str = None, + end_time: str = None, + model_name: str = None, + api_key_id: str = None, + region: str = None, + status: str = None, + page_no: int = 1, + page_size: int = 10, + api_key: str = None, + workspace: str = None, + **kwargs, + ) -> DashScopeAPIResponse: """List async tasks. Args: @@ -220,55 +251,73 @@ def list(cls, Returns: DashScopeAPIResponse: The response data. """ - return super().list(start_time=start_time, - end_time=end_time, - model_name=model_name, - api_key_id=api_key_id, - region=region, - status=status, - page_no=page_no, - page_size=page_size, - api_key=api_key, - workspace=workspace, - **kwargs) - + return super().list( + start_time=start_time, + end_time=end_time, + model_name=model_name, + api_key_id=api_key_id, + region=region, + status=status, + page_no=page_no, + page_size=page_size, + api_key=api_key, + workspace=workspace, + **kwargs, + ) @classmethod - def _preprocess_messages(cls, model: str, messages: List[dict], - api_key: str): + def _preprocess_messages( + cls, + model: str, + messages: List[dict], + api_key: str, + ): """ - messages = [ - { - "role": "user", - "content": [ - {"image": ""}, - {"text": ""}, - ] - } - ] + messages = [ + { + "role": "user", + "content": [ + {"image": ""}, + {"text": ""}, + ] + } + ] """ has_upload = False upload_certificate = None for message in messages: - content = message['content'] + content = message["content"] for elem in content: - if not isinstance(elem, - (int, float, bool, str, bytes, bytearray)): + if not isinstance( + elem, + (int, float, bool, str, bytes, 
bytearray), + ): is_upload, upload_certificate = preprocess_message_element( - model, elem, api_key, upload_certificate) + model, + elem, + api_key, + upload_certificate, # type: ignore[arg-type] + ) if is_upload and not has_upload: has_upload = True return has_upload - @classmethod - def _merge_generation_response(cls, response, n=1) -> Generator[ImageGenerationResponse, None, None]: - """Merge incremental response chunks to simulate non-incremental output.""" + def _merge_generation_response( + cls, + response, + n=1, + ) -> Generator[ImageGenerationResponse, None, None]: + """Merge incremental response chunks to simulate non-incremental output.""" # noqa: E501 accumulated_data = {} for rsp in response: parsed_response = ImageGenerationResponse.from_api_response(rsp) - result = merge_single_response(parsed_response, accumulated_data, n) + result = merge_single_response( + parsed_response, + accumulated_data, + n, + ) if result is True: yield parsed_response elif isinstance(result, list): @@ -278,25 +327,29 @@ def _merge_generation_response(cls, response, n=1) -> Generator[ImageGenerationR class AioImageGeneration(BaseAioApi, BaseAsyncAioApi): - sync_task = 'multimodal-generation' - async_task = 'image-generation' - function = 'generation' + sync_task = "multimodal-generation" + async_task = "image-generation" + function = "generation" """API for AI-Generated Content(AIGC) models. """ + class Models: - wan2_6_image = 'wan2.6-image' - wan2_6_t2i = 'wan2.6-t2i' + wan2_6_image = "wan2.6-image" + wan2_6_t2i = "wan2.6-t2i" @classmethod - async def call( + async def call( # type: ignore[override] cls, model: str, api_key: str = None, messages: List[Message] = None, workspace: str = None, - **kwargs - ) -> Union[ImageGenerationResponse, AsyncGenerator[ImageGenerationResponse, None]]: + **kwargs, + ) -> Union[ + ImageGenerationResponse, + AsyncGenerator[ImageGenerationResponse, None], + ]: """Call generation model service. 
Args: @@ -310,7 +363,7 @@ async def call( {'role': 'assistant', 'content': 'Suitable for outings'}] **kwargs: stream(bool, `optional`): Enable server-sent events - (ref: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events) # noqa E501 + (ref: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events) # noqa E501 # pylint: disable=line-too-long the result will back partially[qwen-turbo,bailian-v1]. workspace (str): The dashscope workspace id. Raises: @@ -322,81 +375,90 @@ async def call( stream is True, return AsyncGenerator, otherwise ImageGenerationResponse. """ if messages is None or not messages: - raise InputRequired('messages is required!') + raise InputRequired("messages is required!") if model is None or not model: - raise ModelRequired('Model is required!') + raise ModelRequired("Model is required!") task_group, _ = _get_task_group_and_task(__name__) _input = {} - if messages is not None and messages: - has_upload = cls._preprocess_messages(model, messages, api_key) + if messages is not None and messages: # type: ignore + has_upload = cls._preprocess_messages(model, messages, api_key) # type: ignore[arg-type] # pylint: disable=line-too-long # noqa: E501 if has_upload: - headers = kwargs.pop('headers', {}) - headers['X-DashScope-OssResourceResolve'] = 'enable' - kwargs['headers'] = headers + headers = kwargs.pop("headers", {}) + headers["X-DashScope-OssResourceResolve"] = "enable" + kwargs["headers"] = headers - _input.update({'messages': messages}) + _input.update({"messages": messages}) # Check if we need to merge incremental output - is_incremental_output = kwargs.get('incremental_output', None) - is_stream = kwargs.get('stream', False) + is_incremental_output = kwargs.get("incremental_output", None) + is_stream = kwargs.get("stream", False) to_merge_incremental_output = False - if (ParamUtil.should_modify_incremental_output(model) and - is_stream and is_incremental_output is not 
None and is_incremental_output is False): + if ( + ParamUtil.should_modify_incremental_output(model) + and is_stream + and is_incremental_output is not None + and is_incremental_output is False + ): to_merge_incremental_output = True - kwargs['incremental_output'] = True + kwargs["incremental_output"] = True # Pass incremental_to_full flag via headers user-agent - if 'headers' not in kwargs: - kwargs['headers'] = {} + if "headers" not in kwargs: + kwargs["headers"] = {} - flag = '1' if to_merge_incremental_output else '0' - kwargs['headers']['user-agent'] = f'incremental_to_full/{flag}' - if kwargs.get('is_async', False): - kwargs['headers']['X-DashScope-Async'] = 'enable' + flag = "1" if to_merge_incremental_output else "0" + kwargs["headers"]["user-agent"] = f"incremental_to_full/{flag}" + if kwargs.get("is_async", False): + kwargs["headers"]["X-DashScope-Async"] = "enable" task = cls.async_task else: task = cls.sync_task - response = await super().call(model=model, - task_group=task_group, - task=task, - function=AioImageGeneration.function, - api_key=api_key, - input=_input, - workspace=workspace, - **kwargs) + response = await super().call( + model=model, + task_group=task_group, + task=task, + function=AioImageGeneration.function, + api_key=api_key, + input=_input, + workspace=workspace, + **kwargs, + ) if is_stream: if to_merge_incremental_output: # Extract n parameter for merge logic - n = kwargs.get('n', 1) + n = kwargs.get("n", 1) return cls._merge_generation_response(response, n) else: return cls._stream_responses(response) else: return ImageGenerationResponse.from_api_response(response) - @classmethod - async def async_call( + async def async_call( # type: ignore[override] cls, model: str, api_key: str = None, messages: List[Message] = None, workspace: str = None, - **kwargs - ) -> Union[ImageGenerationResponse, AsyncGenerator[ImageGenerationResponse, None]]: + **kwargs, + ) -> Union[ + ImageGenerationResponse, + AsyncGenerator[ImageGenerationResponse, 
None], + ]: kwargs["is_async"] = True return await cls.call(model, api_key, messages, workspace, **kwargs) - @classmethod - async def fetch(cls, - task: Union[str, ImageGenerationResponse], - api_key: str = None, - workspace: str = None, - **kwargs) -> DashScopeAPIResponse: + async def fetch( + cls, + task: Union[str, ImageGenerationResponse], # type: ignore[override] + api_key: str = None, + workspace: str = None, + **kwargs, + ) -> DashScopeAPIResponse: """Fetch image(s) synthesis task status or result. Args: @@ -408,15 +470,21 @@ async def fetch(cls, Returns: DashScopeAPIResponse: The task status or result. """ - response = await super().fetch(task, api_key=api_key, workspace=workspace) + response = await super().fetch( + task, + api_key=api_key, + workspace=workspace, + ) return ImageGenerationResponse.from_api_response(response) @classmethod - async def wait(cls, - task: Union[str, ImageGenerationResponse], - api_key: str = None, - workspace: str = None, - **kwargs) -> DashScopeAPIResponse: + async def wait( + cls, + task: Union[str, ImageGenerationResponse], # type: ignore[override] + api_key: str = None, + workspace: str = None, + **kwargs, + ) -> DashScopeAPIResponse: """Wait for image(s) synthesis task to complete, and return the result. Args: @@ -432,11 +500,13 @@ async def wait(cls, return ImageGenerationResponse.from_api_response(response) @classmethod - async def cancel(cls, - task: Union[str, ImageGenerationResponse], - api_key: str = None, - workspace: str = None, - **kwargs) -> DashScopeAPIResponse: + async def cancel( + cls, + task: Union[str, ImageGenerationResponse], # type: ignore[override] + api_key: str = None, + workspace: str = None, + **kwargs, + ) -> DashScopeAPIResponse: """Cancel image synthesis task. Only tasks whose status is PENDING can be canceled. 
@@ -452,18 +522,20 @@ async def cancel(cls, return await super().cancel(task, api_key, workspace=workspace) @classmethod - async def list(cls, - start_time: str = None, - end_time: str = None, - model_name: str = None, - api_key_id: str = None, - region: str = None, - status: str = None, - page_no: int = 1, - page_size: int = 10, - api_key: str = None, - workspace: str = None, - **kwargs) -> DashScopeAPIResponse: + async def list( + cls, + start_time: str = None, + end_time: str = None, + model_name: str = None, + api_key_id: str = None, + region: str = None, + status: str = None, + page_no: int = 1, + page_size: int = 10, + api_key: str = None, + workspace: str = None, + **kwargs, + ) -> DashScopeAPIResponse: """List async tasks. Args: @@ -485,61 +557,84 @@ async def list(cls, Returns: DashScopeAPIResponse: The response data. """ - return await super().list(start_time=start_time, - end_time=end_time, - model_name=model_name, - api_key_id=api_key_id, - region=region, - status=status, - page_no=page_no, - page_size=page_size, - api_key=api_key, - workspace=workspace, - **kwargs) + return await super().list( + start_time=start_time, + end_time=end_time, + model_name=model_name, + api_key_id=api_key_id, + region=region, + status=status, + page_no=page_no, + page_size=page_size, + api_key=api_key, + workspace=workspace, + **kwargs, + ) @classmethod - def _preprocess_messages(cls, model: str, messages: List[dict], - api_key: str): + def _preprocess_messages( + cls, + model: str, + messages: List[dict], + api_key: str, + ): """ - messages = [ - { - "role": "user", - "content": [ - {"image": ""}, - {"text": ""}, - ] - } - ] + messages = [ + { + "role": "user", + "content": [ + {"image": ""}, + {"text": ""}, + ] + } + ] """ has_upload = False upload_certificate = None for message in messages: - content = message['content'] + content = message["content"] for elem in content: - if not isinstance(elem, - (int, float, bool, str, bytes, bytearray)): + if not isinstance( + elem, 
+ (int, float, bool, str, bytes, bytearray), + ): is_upload, upload_certificate = preprocess_message_element( - model, elem, api_key, upload_certificate) + model, + elem, + api_key, + upload_certificate, # type: ignore[arg-type] + ) if is_upload and not has_upload: has_upload = True return has_upload @classmethod - async def _stream_responses(cls, response) -> AsyncGenerator[ImageGenerationResponse, None]: + async def _stream_responses( + cls, + response, + ) -> AsyncGenerator[ImageGenerationResponse, None]: """Convert async response stream to ImageGenerationResponse stream.""" # Type hint: when stream=True, response is actually an AsyncIterable async for rsp in response: # type: ignore yield ImageGenerationResponse.from_api_response(rsp) @classmethod - async def _merge_generation_response(cls, response, n=1) -> AsyncGenerator[ImageGenerationResponse, None]: + async def _merge_generation_response( + cls, + response, + n=1, + ) -> AsyncGenerator[ImageGenerationResponse, None]: """Async version of merge incremental response chunks.""" accumulated_data = {} async for rsp in response: # type: ignore parsed_response = ImageGenerationResponse.from_api_response(rsp) - result = merge_single_response(parsed_response, accumulated_data, n) + result = merge_single_response( + parsed_response, + accumulated_data, + n, + ) if result is True: yield parsed_response elif isinstance(result, list): diff --git a/dashscope/aigc/image_synthesis.py b/dashscope/aigc/image_synthesis.py index 8d6d245..41bd060 100644 --- a/dashscope/aigc/image_synthesis.py +++ b/dashscope/aigc/image_synthesis.py @@ -1,10 +1,18 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
from typing import Any, Dict, List, Union -from dashscope.api_entities.dashscope_response import (DashScopeAPIResponse, - ImageSynthesisResponse) -from dashscope.client.base_api import BaseAsyncApi, BaseApi, BaseAsyncAioApi, BaseAioApi +from dashscope.api_entities.dashscope_response import ( + DashScopeAPIResponse, + ImageSynthesisResponse, +) +from dashscope.client.base_api import ( + BaseAsyncApi, + BaseApi, + BaseAsyncAioApi, + BaseAioApi, +) from dashscope.common.constants import IMAGES, NEGATIVE_PROMPT, PROMPT from dashscope.common.error import InputRequired from dashscope.common.utils import _get_task_group_and_task @@ -12,31 +20,34 @@ class ImageSynthesis(BaseAsyncApi): - task = 'text2image' + task = "text2image" """API for image synthesis. """ + class Models: - wanx_v1 = 'wanx-v1' - wanx_sketch_to_image_v1 = 'wanx-sketch-to-image-v1' + wanx_v1 = "wanx-v1" + wanx_sketch_to_image_v1 = "wanx-sketch-to-image-v1" - wanx_2_1_imageedit = 'wanx2.1-imageedit' + wanx_2_1_imageedit = "wanx2.1-imageedit" @classmethod - def call(cls, - model: str, - prompt: Any, - negative_prompt: Any = None, - images: List[str] = None, - api_key: str = None, - sketch_image_url: str = None, - ref_img: str = None, - workspace: str = None, - extra_input: Dict = None, - task: str = None, - function: str = None, - mask_image_url: str = None, - base_image_url: str = None, - **kwargs) -> ImageSynthesisResponse: + def call( # type: ignore[override] + cls, + model: str, + prompt: Any, + negative_prompt: Any = None, + images: List[str] = None, + api_key: str = None, + sketch_image_url: str = None, + ref_img: str = None, + workspace: str = None, + extra_input: Dict = None, + task: str = None, + function: str = None, + mask_image_url: str = None, + base_image_url: str = None, + **kwargs, + ) -> ImageSynthesisResponse: """Call image(s) synthesis service and get result. 
Args: @@ -56,7 +67,7 @@ def call(cls, colorization,super_resolution,expand,remove_watermaker,doodle, description_edit_with_mask,description_edit,stylization_local,stylization_all base_image_url (str): Enter the URL address of the target edited image. - mask_image_url (str): Provide the URL address of the image of the marked area by the user. It should be consistent with the image resolution of the base_image_url. + mask_image_url (str): Provide the URL address of the image of the marked area by the user. It should be consistent with the image resolution of the base_image_url. # pylint: disable=line-too-long **kwargs: n(int, `optional`): Number of images to synthesis. size(str, `optional`): The output image(s) size(width*height). @@ -74,66 +85,92 @@ def call(cls, Returns: ImageSynthesisResponse: The image(s) synthesis result. """ - return super().call(model, - prompt, - negative_prompt, - images, - api_key=api_key, - sketch_image_url=sketch_image_url, - ref_img=ref_img, - workspace=workspace, - extra_input=extra_input, - task=task, - function=function, - mask_image_url=mask_image_url, - base_image_url=base_image_url, - **kwargs) + return super().call( # type: ignore[return-value] + model, + prompt, + negative_prompt, + images, + api_key=api_key, + sketch_image_url=sketch_image_url, + ref_img=ref_img, + workspace=workspace, + extra_input=extra_input, + task=task, + function=function, + mask_image_url=mask_image_url, + base_image_url=base_image_url, + **kwargs, + ) @classmethod - def sync_call(cls, - model: str, - prompt: Any, - negative_prompt: Any = None, - images: List[str] = None, - api_key: str = None, - sketch_image_url: str = None, - ref_img: str = None, - workspace: str = None, - extra_input: Dict = None, - task: str = None, - function: str = None, - mask_image_url: str = None, - base_image_url: str = None, - **kwargs) -> ImageSynthesisResponse: + def sync_call( + cls, + model: str, + prompt: Any, + negative_prompt: Any = None, + images: List[str] = None, + 
api_key: str = None, + sketch_image_url: str = None, + ref_img: str = None, + workspace: str = None, + extra_input: Dict = None, + task: str = None, + function: str = None, + mask_image_url: str = None, + base_image_url: str = None, + **kwargs, + ) -> ImageSynthesisResponse: """ - Note: This method currently now only supports wan2.2-t2i-flash and wan2.2-t2i-plus. - Using other models will result in an error,More raw image models may be added for use later + Note: This method currently now only supports wan2.2-t2i-flash and wan2.2-t2i-plus. # noqa: E501 # pylint: disable=line-too-long + Using other models will result in an error,More raw image models may be added for use later # pylint: disable=line-too-long """ task_group, f = _get_task_group_and_task(__name__) - inputs, kwargs, task = cls._get_input(model, prompt, negative_prompt, - images, api_key, sketch_image_url, - ref_img, extra_input, task, function, - mask_image_url, base_image_url, **kwargs) - response = BaseApi.call(model, inputs, task_group, task, f, api_key, workspace, **kwargs) + inputs, kwargs, task = cls._get_input( + model, + prompt, + negative_prompt, + images, + api_key, + sketch_image_url, + ref_img, + extra_input, + task, + function, + mask_image_url, + base_image_url, + **kwargs, + ) + response = BaseApi.call( + model, + inputs, + task_group, + task, + f, + api_key, + workspace, + **kwargs, + ) return ImageSynthesisResponse.from_api_response(response) @classmethod - def _get_input(cls, - model: str, - prompt: Any, - negative_prompt: Any = None, - images: List[str] = None, - api_key: str = None, - sketch_image_url: str = None, - ref_img: str = None, - extra_input: Dict = None, - task: str = None, - function: str = None, - mask_image_url: str = None, - base_image_url: str = None, - **kwargs): + def _get_input( # pylint: disable=too-many-branches + cls, + model: str, + prompt: Any, + negative_prompt: Any = None, + images: List[str] = None, + api_key: str = None, + sketch_image_url: str = None, + 
ref_img: str = None, + extra_input: Dict = None, + task: str = None, + function: str = None, + mask_image_url: str = None, + base_image_url: str = None, + **kwargs, + ): if prompt is None or not prompt: - raise InputRequired('prompt is required!') + raise InputRequired("prompt is required!") inputs = {PROMPT: prompt} has_upload = False upload_certificate = None @@ -143,49 +180,85 @@ def _get_input(cls, if images is not None and images and len(images) > 0: new_images = [] for image in images: - is_upload, new_image, upload_certificate = check_and_upload_local( - model, image, api_key, upload_certificate) + ( + is_upload, + new_image, + upload_certificate, + ) = check_and_upload_local( + model, + image, + api_key, + upload_certificate, # type: ignore[arg-type] + ) if is_upload: has_upload = True new_images.append(new_image) inputs[IMAGES] = new_images if sketch_image_url is not None and sketch_image_url: - is_upload, sketch_image_url, upload_certificate = check_and_upload_local( - model, sketch_image_url, api_key, upload_certificate) + ( + is_upload, + sketch_image_url, + upload_certificate, + ) = check_and_upload_local( + model, + sketch_image_url, + api_key, + upload_certificate, # type: ignore[arg-type] + ) if is_upload: has_upload = True - inputs['sketch_image_url'] = sketch_image_url + inputs["sketch_image_url"] = sketch_image_url if ref_img is not None and ref_img: is_upload, ref_img, upload_certificate = check_and_upload_local( - model, ref_img, api_key, upload_certificate) + model, + ref_img, + api_key, + upload_certificate, # type: ignore[arg-type] + ) if is_upload: has_upload = True - inputs['ref_img'] = ref_img + inputs["ref_img"] = ref_img if function is not None and function: - inputs['function'] = function + inputs["function"] = function if mask_image_url is not None and mask_image_url: - is_upload, res_mask_image_url, upload_certificate = check_and_upload_local( - model, mask_image_url, api_key, upload_certificate) + ( + is_upload, + 
res_mask_image_url, + upload_certificate, + ) = check_and_upload_local( + model, + mask_image_url, + api_key, + upload_certificate, # type: ignore[arg-type] + ) if is_upload: has_upload = True - inputs['mask_image_url'] = res_mask_image_url + inputs["mask_image_url"] = res_mask_image_url if base_image_url is not None and base_image_url: - is_upload, res_base_image_url, upload_certificate = check_and_upload_local( - model, base_image_url, api_key, upload_certificate) + ( + is_upload, + res_base_image_url, + upload_certificate, + ) = check_and_upload_local( + model, + base_image_url, + api_key, + upload_certificate, # type: ignore[arg-type] + ) if is_upload: has_upload = True - inputs['base_image_url'] = res_base_image_url + inputs["base_image_url"] = res_base_image_url if extra_input is not None and extra_input: inputs = {**inputs, **extra_input} if has_upload: - headers = kwargs.pop('headers', {}) - headers['X-DashScope-OssResourceResolve'] = 'enable' - kwargs['headers'] = headers + headers = kwargs.pop("headers", {}) + headers["X-DashScope-OssResourceResolve"] = "enable" + kwargs["headers"] = headers def __get_i2i_task(task, model) -> str: # 处理task参数:优先使用有效的task值 @@ -194,8 +267,8 @@ def __get_i2i_task(task, model) -> str: # 根据model确定任务类型 if model is not None and model != "": - if 'imageedit' in model or "wan2.5-i2i" in model: - return 'image2image' + if "imageedit" in model or "wan2.5-i2i" in model: + return "image2image" # 默认返回文本到图像任务 return ImageSynthesis.task @@ -205,21 +278,24 @@ def __get_i2i_task(task, model) -> str: return inputs, kwargs, task @classmethod - def async_call(cls, - model: str, - prompt: Any, - negative_prompt: Any = None, - images: List[str] = None, - api_key: str = None, - sketch_image_url: str = None, - ref_img: str = None, - workspace: str = None, - extra_input: Dict = None, - task: str = None, - function: str = None, - mask_image_url: str = None, - base_image_url: str = None, - **kwargs) -> ImageSynthesisResponse: + # type: 
ignore[override] + def async_call( # pylint: disable=arguments-renamed # type: ignore[override] # noqa: E501 + cls, + model: str, + prompt: Any, + negative_prompt: Any = None, + images: List[str] = None, + api_key: str = None, + sketch_image_url: str = None, + ref_img: str = None, + workspace: str = None, + extra_input: Dict = None, + task: str = None, + function: str = None, + mask_image_url: str = None, + base_image_url: str = None, + **kwargs, + ) -> ImageSynthesisResponse: """Create a image(s) synthesis task, and return task information. Args: @@ -237,7 +313,7 @@ def async_call(cls, colorization,super_resolution,expand,remove_watermaker,doodle, description_edit_with_mask,description_edit,stylization_local,stylization_all base_image_url (str): Enter the URL address of the target edited image. - mask_image_url (str): Provide the URL address of the image of the marked area by the user. It should be consistent with the image resolution of the base_image_url. + mask_image_url (str): Provide the URL address of the image of the marked area by the user. It should be consistent with the image resolution of the base_image_url. # pylint: disable=line-too-long **kwargs(wanx-v1): n(int, `optional`): Number of images to synthesis. size: The output image(s) size, Default 1024*1024 @@ -257,10 +333,21 @@ def async_call(cls, task id in the response. 
""" task_group, f = _get_task_group_and_task(__name__) - inputs, kwargs, task = cls._get_input(model, prompt, negative_prompt, - images, api_key, sketch_image_url, - ref_img, extra_input, task, function, - mask_image_url, base_image_url, **kwargs) + inputs, kwargs, task = cls._get_input( + model, + prompt, + negative_prompt, + images, + api_key, + sketch_image_url, + ref_img, + extra_input, + task, + function, + mask_image_url, + base_image_url, + **kwargs, + ) response = super().async_call( model=model, task_group=task_group, @@ -269,14 +356,17 @@ def async_call(cls, api_key=api_key, input=inputs, workspace=workspace, - **kwargs) + **kwargs, + ) return ImageSynthesisResponse.from_api_response(response) @classmethod - def fetch(cls, - task: Union[str, ImageSynthesisResponse], - api_key: str = None, - workspace: str = None) -> ImageSynthesisResponse: + def fetch( # type: ignore[override] + cls, + task: Union[str, ImageSynthesisResponse], + api_key: str = None, + workspace: str = None, + ) -> ImageSynthesisResponse: """Fetch image(s) synthesis task status or result. Args: @@ -292,10 +382,12 @@ def fetch(cls, return ImageSynthesisResponse.from_api_response(response) @classmethod - def wait(cls, - task: Union[str, ImageSynthesisResponse], - api_key: str = None, - workspace: str = None) -> ImageSynthesisResponse: + def wait( # type: ignore[override] + cls, + task: Union[str, ImageSynthesisResponse], + api_key: str = None, + workspace: str = None, + ) -> ImageSynthesisResponse: """Wait for image(s) synthesis task to complete, and return the result. 
Args: @@ -311,10 +403,12 @@ def wait(cls, return ImageSynthesisResponse.from_api_response(response) @classmethod - def cancel(cls, - task: Union[str, ImageSynthesisResponse], - api_key: str = None, - workspace: str = None) -> DashScopeAPIResponse: + def cancel( # type: ignore[override] + cls, + task: Union[str, ImageSynthesisResponse], + api_key: str = None, + workspace: str = None, + ) -> DashScopeAPIResponse: """Cancel image synthesis task. Only tasks whose status is PENDING can be canceled. @@ -330,18 +424,20 @@ def cancel(cls, return super().cancel(task, api_key, workspace=workspace) @classmethod - def list(cls, - start_time: str = None, - end_time: str = None, - model_name: str = None, - api_key_id: str = None, - region: str = None, - status: str = None, - page_no: int = 1, - page_size: int = 10, - api_key: str = None, - workspace: str = None, - **kwargs) -> DashScopeAPIResponse: + def list( + cls, + start_time: str = None, + end_time: str = None, + model_name: str = None, + api_key_id: str = None, + region: str = None, + status: str = None, + page_no: int = 1, + page_size: int = 10, + api_key: str = None, + workspace: str = None, + **kwargs, + ) -> DashScopeAPIResponse: """List async tasks. Args: @@ -363,35 +459,42 @@ def list(cls, Returns: DashScopeAPIResponse: The response data. 
""" - return super().list(start_time=start_time, - end_time=end_time, - model_name=model_name, - api_key_id=api_key_id, - region=region, - status=status, - page_no=page_no, - page_size=page_size, - api_key=api_key, - workspace=workspace, - **kwargs) + return super().list( + start_time=start_time, + end_time=end_time, + model_name=model_name, + api_key_id=api_key_id, + region=region, + status=status, + page_no=page_no, + page_size=page_size, + api_key=api_key, + workspace=workspace, + **kwargs, + ) + class AioImageSynthesis(BaseAsyncAioApi): + # type: ignore[override] @classmethod - async def call(cls, - model: str, - prompt: Any, - negative_prompt: Any = None, - images: List[str] = None, - api_key: str = None, - sketch_image_url: str = None, - ref_img: str = None, - workspace: str = None, - extra_input: Dict = None, - task: str = None, - function: str = None, - mask_image_url: str = None, - base_image_url: str = None, - **kwargs) -> ImageSynthesisResponse: + async def call( # type: ignore[override] # pylint: disable=arguments-renamed # noqa: E501 + # type: ignore[override] + cls, + model: str, + prompt: Any, + negative_prompt: Any = None, + images: List[str] = None, + api_key: str = None, + sketch_image_url: str = None, + ref_img: str = None, + workspace: str = None, + extra_input: Dict = None, + task: str = None, + function: str = None, + mask_image_url: str = None, + base_image_url: str = None, + **kwargs, + ) -> ImageSynthesisResponse: """Call image(s) synthesis service and get result. Args: @@ -411,7 +514,7 @@ async def call(cls, colorization,super_resolution,expand,remove_watermaker,doodle, description_edit_with_mask,description_edit,stylization_local,stylization_all base_image_url (str): Enter the URL address of the target edited image. - mask_image_url (str): Provide the URL address of the image of the marked area by the user. It should be consistent with the image resolution of the base_image_url. 
+ mask_image_url (str): Provide the URL address of the image of the marked area by the user. It should be consistent with the image resolution of the base_image_url. # pylint: disable=line-too-long **kwargs: n(int, `optional`): Number of images to synthesis. size(str, `optional`): The output image(s) size(width*height). @@ -430,59 +533,111 @@ async def call(cls, ImageSynthesisResponse: The image(s) synthesis result. """ task_group, f = _get_task_group_and_task(__name__) - inputs, kwargs, task = ImageSynthesis._get_input(model, prompt, negative_prompt, - images, api_key, sketch_image_url, - ref_img, extra_input, task, function, - mask_image_url, base_image_url, **kwargs) - response = await super().call(model, inputs, task_group, task, f, api_key, workspace, **kwargs) + # pylint: disable=protected-access + inputs, kwargs, task = ImageSynthesis._get_input( + model, + prompt, + negative_prompt, + images, + api_key, + sketch_image_url, + ref_img, + extra_input, + task, + function, + mask_image_url, + base_image_url, + **kwargs, + ) + response = await super().call( + model, + inputs, + task_group, + task, + f, + api_key, + workspace, + **kwargs, + ) return ImageSynthesisResponse.from_api_response(response) @classmethod - async def sync_call(cls, - model: str, - prompt: Any, - negative_prompt: Any = None, - images: List[str] = None, - api_key: str = None, - sketch_image_url: str = None, - ref_img: str = None, - workspace: str = None, - extra_input: Dict = None, - task: str = None, - function: str = None, - mask_image_url: str = None, - base_image_url: str = None, - **kwargs) -> ImageSynthesisResponse: + async def sync_call( + cls, + model: str, + prompt: Any, + negative_prompt: Any = None, + images: List[str] = None, + api_key: str = None, + sketch_image_url: str = None, + ref_img: str = None, + workspace: str = None, + extra_input: Dict = None, + task: str = None, + function: str = None, + mask_image_url: str = None, + base_image_url: str = None, + **kwargs, + ) -> 
ImageSynthesisResponse: """ - Note: This method currently now only supports wan2.2-t2i-flash and wan2.2-t2i-plus. - Using other models will result in an error,More raw image models may be added for use later + Note: This method currently now only supports wan2.2-t2i-flash and wan2.2-t2i-plus. # noqa: E501 # pylint: disable=line-too-long + Using other models will result in an error,More raw image models may be added for use later # pylint: disable=line-too-long """ task_group, f = _get_task_group_and_task(__name__) - inputs, kwargs, task = ImageSynthesis._get_input(model, prompt, negative_prompt, - images, api_key, sketch_image_url, - ref_img, extra_input, task, function, - mask_image_url, base_image_url, **kwargs) - response = await BaseAioApi.call(model, inputs, task_group, task, f, api_key, workspace, **kwargs) + # pylint: disable=protected-access + inputs, kwargs, task = ImageSynthesis._get_input( + model, + prompt, + negative_prompt, + images, + api_key, + sketch_image_url, + ref_img, + extra_input, + task, + function, + mask_image_url, + base_image_url, + **kwargs, + ) + response = await BaseAioApi.call( + model, + inputs, + task_group, + task, + f, + api_key, + workspace, + **kwargs, + ) return ImageSynthesisResponse.from_api_response(response) @classmethod - async def async_call(cls, - model: str, - prompt: Any, - negative_prompt: Any = None, - images: List[str] = None, - api_key: str = None, - sketch_image_url: str = None, - ref_img: str = None, - workspace: str = None, - extra_input: Dict = None, - task: str = None, - function: str = None, - mask_image_url: str = None, - base_image_url: str = None, - **kwargs) -> ImageSynthesisResponse: + async def async_call( # type: ignore[override] # pylint: disable=arguments-renamed # noqa: E501 + cls, + model: str, + prompt: Any, + negative_prompt: Any = None, + images: List[str] = None, + api_key: str = None, + sketch_image_url: str = None, + ref_img: str = None, + workspace: str = None, + extra_input: Dict = None, + 
task: str = None, + function: str = None, + mask_image_url: str = None, + base_image_url: str = None, + **kwargs, + ) -> ImageSynthesisResponse: """Create a image(s) synthesis task, and return task information. + Note: This method overrides BaseAsyncAioApi.async_call() with + renamed parameters to provide a more user-friendly API. The + generic parameters (input_data, task_group) are replaced with + domain-specific ones (prompt, negative_prompt, images, etc.). + Pylint's arguments-renamed warning is disabled for this reason. + Args: model (str): The model, reference ``Models``. prompt (Any): The prompt for image(s) synthesis. @@ -498,7 +653,7 @@ async def async_call(cls, colorization,super_resolution,expand,remove_watermaker,doodle, description_edit_with_mask,description_edit,stylization_local,stylization_all base_image_url (str): Enter the URL address of the target edited image. - mask_image_url (str): Provide the URL address of the image of the marked area by the user. It should be consistent with the image resolution of the base_image_url. + mask_image_url (str): Provide the URL address of the image of the marked area by the user. It should be consistent with the image resolution of the base_image_url. # pylint: disable=line-too-long **kwargs(wanx-v1): n(int, `optional`): Number of images to synthesis. size: The output image(s) size, Default 1024*1024 @@ -518,19 +673,42 @@ async def async_call(cls, task id in the response. 
""" task_group, f = _get_task_group_and_task(__name__) - inputs, kwargs, task = ImageSynthesis._get_input(model, prompt, negative_prompt, - images, api_key, sketch_image_url, - ref_img, extra_input, task, function, - mask_image_url, base_image_url, **kwargs) - response = await super().async_call(model, inputs, task_group, task, f, api_key, workspace, **kwargs) + # pylint: disable=protected-access + inputs, kwargs, task = ImageSynthesis._get_input( + model, + prompt, + negative_prompt, + images, + api_key, + sketch_image_url, + ref_img, + extra_input, + task, + function, + mask_image_url, + base_image_url, + **kwargs, + ) + response = await super().async_call( + model, + inputs, + task_group, + task, + f, + api_key, + workspace, + **kwargs, + ) return ImageSynthesisResponse.from_api_response(response) @classmethod - async def fetch(cls, - task: Union[str, ImageSynthesisResponse], - api_key: str = None, - workspace: str = None, - **kwargs,) -> ImageSynthesisResponse: + async def fetch( + cls, + task: Union[str, ImageSynthesisResponse], # type: ignore[override] + api_key: str = None, + workspace: str = None, + **kwargs, + ) -> ImageSynthesisResponse: """Fetch image(s) synthesis task status or result. Args: @@ -542,15 +720,21 @@ async def fetch(cls, Returns: ImageSynthesisResponse: The task status or result. """ - response = await super().fetch(task, api_key=api_key, workspace=workspace) + response = await super().fetch( + task, + api_key=api_key, + workspace=workspace, + ) return ImageSynthesisResponse.from_api_response(response) @classmethod - async def wait(cls, - task: Union[str, ImageSynthesisResponse], - api_key: str = None, - workspace: str = None, - **kwargs) -> ImageSynthesisResponse: + async def wait( + cls, + task: Union[str, ImageSynthesisResponse], # type: ignore[override] + api_key: str = None, + workspace: str = None, + **kwargs, + ) -> ImageSynthesisResponse: """Wait for image(s) synthesis task to complete, and return the result. 
Args: @@ -566,11 +750,13 @@ async def wait(cls, return ImageSynthesisResponse.from_api_response(response) @classmethod - async def cancel(cls, - task: Union[str, ImageSynthesisResponse], - api_key: str = None, - workspace: str = None, - **kwargs,) -> DashScopeAPIResponse: + async def cancel( + cls, + task: Union[str, ImageSynthesisResponse], # type: ignore[override] + api_key: str = None, + workspace: str = None, + **kwargs, + ) -> DashScopeAPIResponse: """Cancel image synthesis task. Only tasks whose status is PENDING can be canceled. @@ -586,18 +772,20 @@ async def cancel(cls, return await super().cancel(task, api_key, workspace=workspace) @classmethod - async def list(cls, - start_time: str = None, - end_time: str = None, - model_name: str = None, - api_key_id: str = None, - region: str = None, - status: str = None, - page_no: int = 1, - page_size: int = 10, - api_key: str = None, - workspace: str = None, - **kwargs) -> DashScopeAPIResponse: + async def list( + cls, + start_time: str = None, + end_time: str = None, + model_name: str = None, + api_key_id: str = None, + region: str = None, + status: str = None, + page_no: int = 1, + page_size: int = 10, + api_key: str = None, + workspace: str = None, + **kwargs, + ) -> DashScopeAPIResponse: """List async tasks. Args: @@ -619,14 +807,16 @@ async def list(cls, Returns: DashScopeAPIResponse: The response data. 
""" - return await super().list(start_time=start_time, - end_time=end_time, - model_name=model_name, - api_key_id=api_key_id, - region=region, - status=status, - page_no=page_no, - page_size=page_size, - api_key=api_key, - workspace=workspace, - **kwargs) \ No newline at end of file + return await super().list( + start_time=start_time, + end_time=end_time, + model_name=model_name, + api_key_id=api_key_id, + region=region, + status=status, + page_no=page_no, + page_size=page_size, + api_key=api_key, + workspace=workspace, + **kwargs, + ) diff --git a/dashscope/aigc/multimodal_conversation.py b/dashscope/aigc/multimodal_conversation.py index 6d82115..ac4ae51 100644 --- a/dashscope/aigc/multimodal_conversation.py +++ b/dashscope/aigc/multimodal_conversation.py @@ -1,12 +1,14 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. import copy from typing import AsyncGenerator, Generator, List, Union -from dashscope.api_entities.dashscope_response import \ - MultiModalConversationResponse +from dashscope.api_entities.dashscope_response import ( + MultiModalConversationResponse, +) from dashscope.client.base_api import BaseAioApi, BaseApi -from dashscope.common.error import InputRequired, ModelRequired +from dashscope.common.error import ModelRequired from dashscope.common.utils import _get_task_group_and_task from dashscope.utils.oss_utils import preprocess_message_element from dashscope.utils.param_utils import ParamUtil @@ -14,16 +16,17 @@ class MultiModalConversation(BaseApi): - """MultiModal conversational robot interface. 
- """ - task = 'multimodal-generation' - function = 'generation' + """MultiModal conversational robot interface.""" + + task = "multimodal-generation" + function = "generation" class Models: - qwen_vl_chat_v1 = 'qwen-vl-chat-v1' + qwen_vl_chat_v1 = "qwen-vl-chat-v1" @classmethod - def call( + # type: ignore + def call( # pylint: disable=arguments-renamed,too-many-branches cls, model: str, messages: List = None, @@ -32,9 +35,15 @@ def call( text: str = None, voice: str = None, language_type: str = None, - **kwargs - ) -> Union[MultiModalConversationResponse, Generator[ - MultiModalConversationResponse, None, None]]: + **kwargs, + ) -> Union[ + MultiModalConversationResponse, + Generator[ + MultiModalConversationResponse, + None, + None, + ], + ]: """Call the conversation model service. Args: @@ -58,15 +67,15 @@ def call( ] api_key (str, optional): The api api_key, can be None, if None, will retrieve by rule [1]. - [1]: https://help.aliyun.com/zh/dashscope/developer-reference/api-key-settings. # noqa E501 + [1]: https://help.aliyun.com/zh/dashscope/developer-reference/api-key-settings. # noqa E501 # pylint: disable=line-too-long workspace (str): The dashscope workspace id. text (str): The text to generate. - voice (str): The voice name of qwen tts, include 'Cherry'/'Ethan'/'Sunny'/'Dylan' and so on, - you can get the total voice list : https://help.aliyun.com/zh/model-studio/qwen-tts. - language_type (str): The synthesized language type, default is 'auto', useful for [qwen3-tts]. + voice (str): The voice name of qwen tts, include 'Cherry'/'Ethan'/'Sunny'/'Dylan' and so on, # pylint: disable=line-too-long + you can get the total voice list : https://help.aliyun.com/zh/model-studio/qwen-tts. # pylint: disable=line-too-long + language_type (str): The synthesized language type, default is 'auto', useful for [qwen3-tts]. 
# pylint: disable=line-too-long **kwargs: stream(bool, `optional`): Enable server-sent events - (ref: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events) # noqa E501 + (ref: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events) # noqa E501 # pylint: disable=line-too-long the result will back partially[qwen-turbo,bailian-v1]. max_length(int, `optional`): The maximum length of tokens to generate. The token count of your prompt plus max_length @@ -89,98 +98,126 @@ def call( stream is True, return Generator, otherwise MultiModalConversationResponse. """ if model is None or not model: - raise ModelRequired('Model is required!') + raise ModelRequired("Model is required!") task_group, _ = _get_task_group_and_task(__name__) - input = {} + input = {} # pylint: disable=redefined-builtin msg_copy = None if messages is not None and messages: msg_copy = copy.deepcopy(messages) has_upload = cls._preprocess_messages(model, msg_copy, api_key) if has_upload: - headers = kwargs.pop('headers', {}) - headers['X-DashScope-OssResourceResolve'] = 'enable' - kwargs['headers'] = headers + headers = kwargs.pop("headers", {}) + headers["X-DashScope-OssResourceResolve"] = "enable" + kwargs["headers"] = headers if text is not None and text: - input.update({'text': text}) + input.update({"text": text}) if voice is not None and voice: - input.update({'voice': voice}) + input.update({"voice": voice}) if language_type is not None and language_type: - input.update({'language_type': language_type}) + input.update({"language_type": language_type}) if msg_copy is not None: - input.update({'messages': msg_copy}) + input.update({"messages": msg_copy}) # type: ignore # Check if we need to merge incremental output - is_incremental_output = kwargs.get('incremental_output', None) + is_incremental_output = kwargs.get("incremental_output", None) to_merge_incremental_output = False - is_stream = kwargs.get('stream', False) - if 
(ParamUtil.should_modify_incremental_output(model) and - is_stream and is_incremental_output is not None and is_incremental_output is False): + is_stream = kwargs.get("stream", False) + if ( + ParamUtil.should_modify_incremental_output(model) + and is_stream + and is_incremental_output is not None + and is_incremental_output is False + ): to_merge_incremental_output = True - kwargs['incremental_output'] = True + kwargs["incremental_output"] = True # Pass incremental_to_full flag via headers user-agent - if 'headers' not in kwargs: - kwargs['headers'] = {} - flag = '1' if to_merge_incremental_output else '0' - kwargs['headers']['user-agent'] = f'incremental_to_full/{flag}' - - response = super().call(model=model, - task_group=task_group, - task=MultiModalConversation.task, - function=MultiModalConversation.function, - api_key=api_key, - input=input, - workspace=workspace, - **kwargs) + if "headers" not in kwargs: + kwargs["headers"] = {} + flag = "1" if to_merge_incremental_output else "0" + kwargs["headers"]["user-agent"] = f"incremental_to_full/{flag}" + + response = super().call( + model=model, + task_group=task_group, + task=MultiModalConversation.task, + function=MultiModalConversation.function, + api_key=api_key, + input=input, + workspace=workspace, + **kwargs, + ) if is_stream: if to_merge_incremental_output: # Extract n parameter for merge logic - n = kwargs.get('n', 1) + n = kwargs.get("n", 1) return cls._merge_multimodal_response(response, n) else: - return (MultiModalConversationResponse.from_api_response(rsp) - for rsp in response) + return ( + MultiModalConversationResponse.from_api_response(rsp) + for rsp in response + ) else: return MultiModalConversationResponse.from_api_response(response) @classmethod - def _preprocess_messages(cls, model: str, messages: List[dict], - api_key: str): + def _preprocess_messages( + cls, + model: str, + messages: List[dict], + api_key: str, + ): """ - messages = [ - { - "role": "user", - "content": [ - {"image": ""}, - 
{"text": ""}, - ] - } - ] + messages = [ + { + "role": "user", + "content": [ + {"image": ""}, + {"text": ""}, + ] + } + ] """ has_upload = False upload_certificate = None for message in messages: - content = message['content'] + content = message["content"] for elem in content: - if not isinstance(elem, - (int, float, bool, str, bytes, bytearray)): + if not isinstance( + elem, + (int, float, bool, str, bytes, bytearray), + ): is_upload, upload_certificate = preprocess_message_element( - model, elem, api_key, upload_certificate) + model, + elem, + api_key, + upload_certificate, # type: ignore[arg-type] + ) if is_upload and not has_upload: has_upload = True return has_upload @classmethod - def _merge_multimodal_response(cls, response, n=1) -> Generator[MultiModalConversationResponse, None, None]: - """Merge incremental response chunks to simulate non-incremental output.""" + def _merge_multimodal_response( + cls, + response, + n=1, + ) -> Generator[MultiModalConversationResponse, None, None]: + """Merge incremental response chunks to simulate non-incremental output.""" # noqa: E501 accumulated_data = {} for rsp in response: - parsed_response = MultiModalConversationResponse.from_api_response(rsp) - result = merge_multimodal_single_response(parsed_response, accumulated_data, n) + parsed_response = MultiModalConversationResponse.from_api_response( + rsp, + ) + result = merge_multimodal_single_response( + parsed_response, + accumulated_data, + n, + ) if result is True: yield parsed_response elif isinstance(result, list): @@ -190,16 +227,16 @@ def _merge_multimodal_response(cls, response, n=1) -> Generator[MultiModalConver class AioMultiModalConversation(BaseAioApi): - """Async MultiModal conversational robot interface. 
- """ - task = 'multimodal-generation' - function = 'generation' + """Async MultiModal conversational robot interface.""" + + task = "multimodal-generation" + function = "generation" class Models: - qwen_vl_chat_v1 = 'qwen-vl-chat-v1' + qwen_vl_chat_v1 = "qwen-vl-chat-v1" - @classmethod - async def call( + @classmethod # type: ignore + async def call( # pylint: disable=arguments-renamed,too-many-branches cls, model: str, messages: List = None, @@ -208,9 +245,14 @@ async def call( text: str = None, voice: str = None, language_type: str = None, - **kwargs - ) -> Union[MultiModalConversationResponse, AsyncGenerator[ - MultiModalConversationResponse, None]]: + **kwargs, + ) -> Union[ + MultiModalConversationResponse, + AsyncGenerator[ + MultiModalConversationResponse, + None, + ], + ]: """Call the conversation model service asynchronously. Args: @@ -234,15 +276,15 @@ async def call( ] api_key (str, optional): The api api_key, can be None, if None, will retrieve by rule [1]. - [1]: https://help.aliyun.com/zh/dashscope/developer-reference/api-key-settings. # noqa E501 + [1]: https://help.aliyun.com/zh/dashscope/developer-reference/api-key-settings. # noqa E501 # pylint: disable=line-too-long workspace (str): The dashscope workspace id. text (str): The text to generate. - voice (str): The voice name of qwen tts, include 'Cherry'/'Ethan'/'Sunny'/'Dylan' and so on, - you can get the total voice list : https://help.aliyun.com/zh/model-studio/qwen-tts. - language_type (str): The synthesized language type, default is 'auto', useful for [qwen3-tts]. + voice (str): The voice name of qwen tts, include 'Cherry'/'Ethan'/'Sunny'/'Dylan' and so on, # pylint: disable=line-too-long + you can get the total voice list : https://help.aliyun.com/zh/model-studio/qwen-tts. # pylint: disable=line-too-long + language_type (str): The synthesized language type, default is 'auto', useful for [qwen3-tts]. 
# pylint: disable=line-too-long **kwargs: stream(bool, `optional`): Enable server-sent events - (ref: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events) # noqa E501 + (ref: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events) # noqa E501 # pylint: disable=line-too-long the result will back partially[qwen-turbo,bailian-v1]. max_length(int, `optional`): The maximum length of tokens to generate. The token count of your prompt plus max_length @@ -264,58 +306,64 @@ async def call( stream is True, return AsyncGenerator, otherwise MultiModalConversationResponse. """ if model is None or not model: - raise ModelRequired('Model is required!') + raise ModelRequired("Model is required!") task_group, _ = _get_task_group_and_task(__name__) - input = {} + input = {} # pylint: disable=redefined-builtin msg_copy = None if messages is not None and messages: msg_copy = copy.deepcopy(messages) has_upload = cls._preprocess_messages(model, msg_copy, api_key) if has_upload: - headers = kwargs.pop('headers', {}) - headers['X-DashScope-OssResourceResolve'] = 'enable' - kwargs['headers'] = headers + headers = kwargs.pop("headers", {}) + headers["X-DashScope-OssResourceResolve"] = "enable" + kwargs["headers"] = headers if text is not None and text: - input.update({'text': text}) + input.update({"text": text}) if voice is not None and voice: - input.update({'voice': voice}) + input.update({"voice": voice}) if language_type is not None and language_type: - input.update({'language_type': language_type}) + input.update({"language_type": language_type}) if msg_copy is not None: - input.update({'messages': msg_copy}) + input.update({"messages": msg_copy}) # type: ignore # Check if we need to merge incremental output - is_incremental_output = kwargs.get('incremental_output', None) + is_incremental_output = kwargs.get("incremental_output", None) to_merge_incremental_output = False - is_stream = kwargs.get('stream', 
False) - if (ParamUtil.should_modify_incremental_output(model) and - is_stream and is_incremental_output is not None and is_incremental_output is False): + is_stream = kwargs.get("stream", False) + if ( + ParamUtil.should_modify_incremental_output(model) + and is_stream + and is_incremental_output is not None + and is_incremental_output is False + ): to_merge_incremental_output = True - kwargs['incremental_output'] = True + kwargs["incremental_output"] = True # Pass incremental_to_full flag via headers user-agent - if 'headers' not in kwargs: - kwargs['headers'] = {} - flag = '1' if to_merge_incremental_output else '0' - kwargs['headers']['user-agent'] = ( - kwargs['headers'].get('user-agent', '') + - f'; incremental_to_full/{flag}' + if "headers" not in kwargs: + kwargs["headers"] = {} + flag = "1" if to_merge_incremental_output else "0" + kwargs["headers"]["user-agent"] = ( + kwargs["headers"].get("user-agent", "") + + f"; incremental_to_full/{flag}" ) - response = await super().call(model=model, - task_group=task_group, - task=AioMultiModalConversation.task, - function=AioMultiModalConversation.function, - api_key=api_key, - input=input, - workspace=workspace, - **kwargs) + response = await super().call( + model=model, + task_group=task_group, + task=AioMultiModalConversation.task, + function=AioMultiModalConversation.function, + api_key=api_key, + input=input, + workspace=workspace, + **kwargs, + ) if is_stream: if to_merge_incremental_output: # Extract n parameter for merge logic - n = kwargs.get('n', 1) + n = kwargs.get("n", 1) return cls._merge_multimodal_response(response, n) else: return cls._stream_responses(response) @@ -323,53 +371,74 @@ async def call( return MultiModalConversationResponse.from_api_response(response) @classmethod - def _preprocess_messages(cls, model: str, messages: List[dict], - api_key: str): + def _preprocess_messages( + cls, + model: str, + messages: List[dict], + api_key: str, + ): """ - messages = [ - { - "role": "user", - 
"content": [ - {"image": ""}, - {"text": ""}, - ] - } - ] + messages = [ + { + "role": "user", + "content": [ + {"image": ""}, + {"text": ""}, + ] + } + ] """ has_upload = False upload_certificate = None for message in messages: - content = message['content'] + content = message["content"] for elem in content: - if not isinstance(elem, - (int, float, bool, str, bytes, bytearray)): + if not isinstance( + elem, + (int, float, bool, str, bytes, bytearray), + ): is_upload, upload_certificate = preprocess_message_element( - model, elem, api_key, upload_certificate) + model, + elem, + api_key, + upload_certificate, # type: ignore[arg-type] + ) if is_upload and not has_upload: has_upload = True return has_upload @classmethod - async def _stream_responses(cls, response) -> AsyncGenerator[MultiModalConversationResponse, None]: - """Convert async response stream to MultiModalConversationResponse stream.""" + async def _stream_responses( + cls, + response, + ) -> AsyncGenerator[MultiModalConversationResponse, None]: + """Convert async response stream to MultiModalConversationResponse stream.""" # noqa: E501 # Type hint: when stream=True, response is actually an AsyncIterable async for rsp in response: # type: ignore yield MultiModalConversationResponse.from_api_response(rsp) @classmethod - async def _merge_multimodal_response(cls, response, n=1) -> AsyncGenerator[MultiModalConversationResponse, None]: + async def _merge_multimodal_response( + cls, + response, + n=1, + ) -> AsyncGenerator[MultiModalConversationResponse, None]: """Async version of merge incremental response chunks.""" accumulated_data = {} async for rsp in response: - parsed_response = MultiModalConversationResponse.from_api_response(rsp) - result = merge_multimodal_single_response(parsed_response, accumulated_data, n) + parsed_response = MultiModalConversationResponse.from_api_response( + rsp, + ) + result = merge_multimodal_single_response( + parsed_response, + accumulated_data, + n, + ) if result is True: 
yield parsed_response elif isinstance(result, list): # Multiple responses to yield (for n>1 non-stop cases) for resp in result: yield resp - - diff --git a/dashscope/aigc/video_synthesis.py b/dashscope/aigc/video_synthesis.py index 085376c..008faec 100644 --- a/dashscope/aigc/video_synthesis.py +++ b/dashscope/aigc/video_synthesis.py @@ -1,9 +1,12 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. from typing import Any, Dict, Union, List -from dashscope.api_entities.dashscope_response import (DashScopeAPIResponse, - VideoSynthesisResponse) +from dashscope.api_entities.dashscope_response import ( + DashScopeAPIResponse, + VideoSynthesisResponse, +) from dashscope.client.base_api import BaseAsyncApi, BaseAsyncAioApi from dashscope.common.constants import PROMPT, REFERENCE_VIDEO_URLS from dashscope.common.utils import _get_task_group_and_task @@ -11,45 +14,49 @@ class VideoSynthesis(BaseAsyncApi): - task = 'video-generation' + task = "video-generation" """API for video synthesis. 
""" + class Models: """@deprecated, use wanx2.1-t2v-plus instead""" - wanx_txt2video_pro = 'wanx-txt2video-pro' + + wanx_txt2video_pro = "wanx-txt2video-pro" """@deprecated, use wanx2.1-i2v-plus instead""" - wanx_img2video_pro = 'wanx-img2video-pro' + wanx_img2video_pro = "wanx-img2video-pro" - wanx_2_1_t2v_turbo = 'wanx2.1-t2v-turbo' - wanx_2_1_t2v_plus = 'wanx2.1-t2v-plus' + wanx_2_1_t2v_turbo = "wanx2.1-t2v-turbo" + wanx_2_1_t2v_plus = "wanx2.1-t2v-plus" - wanx_2_1_i2v_plus = 'wanx2.1-i2v-plus' - wanx_2_1_i2v_turbo = 'wanx2.1-i2v-turbo' + wanx_2_1_i2v_plus = "wanx2.1-i2v-plus" + wanx_2_1_i2v_turbo = "wanx2.1-i2v-turbo" - wanx_2_1_kf2v_plus = 'wanx2.1-kf2v-plus' - wanx_kf2v = 'wanx-kf2v' + wanx_2_1_kf2v_plus = "wanx2.1-kf2v-plus" + wanx_kf2v = "wanx-kf2v" @classmethod - def call(cls, - model: str, - prompt: Any = None, - # """@deprecated, use prompt_extend in parameters """ - extend_prompt: bool = True, - negative_prompt: str = None, - template: str = None, - img_url: str = None, - audio_url: str = None, - reference_video_urls: List[str] = None, - reference_video_description: List[str] = None, - api_key: str = None, - extra_input: Dict = None, - workspace: str = None, - task: str = None, - head_frame: str = None, - tail_frame: str = None, - first_frame_url: str = None, - last_frame_url: str = None, - **kwargs) -> VideoSynthesisResponse: + def call( # type: ignore[override] + cls, + model: str, + prompt: Any = None, + # """@deprecated, use prompt_extend in parameters """ + extend_prompt: bool = True, + negative_prompt: str = None, + template: str = None, + img_url: str = None, + audio_url: str = None, + reference_video_urls: List[str] = None, + reference_video_description: List[str] = None, + api_key: str = None, + extra_input: Dict = None, + workspace: str = None, + task: str = None, + head_frame: str = None, + tail_frame: str = None, + first_frame_url: str = None, + last_frame_url: str = None, + **kwargs, + ) -> VideoSynthesisResponse: """Call video synthesis 
service and get result. Args: @@ -58,10 +65,10 @@ def call(cls, extend_prompt (bool): @deprecated, use prompt_extend in parameters negative_prompt (str): The negative prompt is the opposite of the prompt meaning. template (str): LoRa input, such as gufeng, katong, etc. - img_url (str): The input image url, Generate the URL of the image referenced by the video. + img_url (str): The input image url, Generate the URL of the image referenced by the video. # pylint: disable=line-too-long audio_url (str): The input audio url - reference_video_urls (List[str]): list of character reference video file urls uploaded by the user - reference_video_description (List[str]): For the description information of the picture and sound of the reference video, corresponding to ref video, it needs to be in the order of the url. If the quantity is different, an error will be reported + reference_video_urls (List[str]): list of character reference video file urls uploaded by the user # pylint: disable=line-too-long + reference_video_description (List[str]): For the description information of the picture and sound of the reference video, corresponding to ref video, it needs to be in the order of the url. If the quantity is different, an error will be reported # pylint: disable=line-too-long api_key (str, optional): The api api_key. Defaults to None. workspace (str): The dashscope workspace id. extra_input (Dict): The extra input parameters. @@ -70,8 +77,11 @@ def call(cls, last_frame_url (str): The URL of the last frame image for generating the video. **kwargs: size(str, `optional`): The output video size(width*height). - duration(int, optional): The duration. Duration of video generation. The default value is 5, in seconds. - seed(int, optional): The seed. The random seed for video generation. The default value is 5. + duration( + int, + optional + ): The duration. Duration of video generation. The default value is 5, in seconds. # noqa: E501 + seed(int, optional): The seed. 
The random seed for video generation. The default value is 5. # noqa: E501 # pylint: disable=line-too-long Raises: InputRequired: The prompt cannot be empty. @@ -79,108 +89,171 @@ def call(cls, Returns: VideoSynthesisResponse: The video synthesis result. """ - return super().call(model, - prompt, - img_url=img_url, - audio_url=audio_url, - reference_video_urls=reference_video_urls, - reference_video_description=reference_video_description, - api_key=api_key, - extend_prompt=extend_prompt, - negative_prompt=negative_prompt, - template=template, - workspace=workspace, - extra_input=extra_input, - task=task, - head_frame=head_frame, - tail_frame=tail_frame, - first_frame_url=first_frame_url, - last_frame_url=last_frame_url, - **kwargs) + return super().call( # type: ignore[return-value] + model, + prompt, + img_url=img_url, + audio_url=audio_url, + reference_video_urls=reference_video_urls, + reference_video_description=reference_video_description, + api_key=api_key, + extend_prompt=extend_prompt, + negative_prompt=negative_prompt, + template=template, + workspace=workspace, + extra_input=extra_input, + task=task, + head_frame=head_frame, + tail_frame=tail_frame, + first_frame_url=first_frame_url, + last_frame_url=last_frame_url, + **kwargs, + ) @classmethod - def _get_input(cls, - model: str, - prompt: Any = None, - img_url: str = None, - audio_url: str = None, - reference_video_urls: List[str] = None, - reference_video_description: List[str] = None, - # """@deprecated, use prompt_extend in parameters """ - extend_prompt: bool = True, - negative_prompt: str = None, - template: str = None, - api_key: str = None, - extra_input: Dict = None, - task: str = None, - function: str = None, - head_frame: str = None, - tail_frame: str = None, - first_frame_url: str = None, - last_frame_url: str = None, - **kwargs): - - inputs = {PROMPT: prompt, 'extend_prompt': extend_prompt} + # pylint: disable=too-many-statements + def _get_input( # pylint: disable=too-many-branches + cls, + 
model: str, + prompt: Any = None, + img_url: str = None, + audio_url: str = None, + reference_video_urls: List[str] = None, + reference_video_description: List[str] = None, + # """@deprecated, use prompt_extend in parameters """ + extend_prompt: bool = True, + negative_prompt: str = None, + template: str = None, + api_key: str = None, + extra_input: Dict = None, + task: str = None, + function: str = None, + head_frame: str = None, + tail_frame: str = None, + first_frame_url: str = None, + last_frame_url: str = None, + **kwargs, + ): + inputs = {PROMPT: prompt, "extend_prompt": extend_prompt} if negative_prompt: - inputs['negative_prompt'] = negative_prompt + inputs["negative_prompt"] = negative_prompt if template: - inputs['template'] = template + inputs["template"] = template if function: - inputs['function'] = function + inputs["function"] = function if reference_video_description: - inputs['reference_video_description'] = reference_video_description + inputs["reference_video_description"] = reference_video_description has_upload = False upload_certificate = None if img_url is not None and img_url: - is_upload, res_img_url, upload_certificate = check_and_upload_local( - model, img_url, api_key, upload_certificate) + ( + is_upload, + res_img_url, + upload_certificate, + ) = check_and_upload_local( + model, + img_url, + api_key, + upload_certificate, # type: ignore[arg-type] + ) if is_upload: has_upload = True - inputs['img_url'] = res_img_url + inputs["img_url"] = res_img_url if audio_url is not None and audio_url: - is_upload, res_audio_url, upload_certificate = check_and_upload_local( - model, audio_url, api_key, upload_certificate) + ( + is_upload, + res_audio_url, + upload_certificate, + ) = check_and_upload_local( + model, + audio_url, + api_key, + upload_certificate, # type: ignore[arg-type] + ) if is_upload: has_upload = True - inputs['audio_url'] = res_audio_url + inputs["audio_url"] = res_audio_url if head_frame is not None and head_frame: - is_upload, 
res_head_frame, upload_certificate = check_and_upload_local( - model, head_frame, api_key, upload_certificate) + ( + is_upload, + res_head_frame, + upload_certificate, + ) = check_and_upload_local( + model, + head_frame, + api_key, + upload_certificate, # type: ignore[arg-type] + ) if is_upload: has_upload = True - inputs['head_frame'] = res_head_frame + inputs["head_frame"] = res_head_frame if tail_frame is not None and tail_frame: - is_upload, res_tail_frame, upload_certificate = check_and_upload_local( - model, tail_frame, api_key, upload_certificate) + ( + is_upload, + res_tail_frame, + upload_certificate, + ) = check_and_upload_local( + model, + tail_frame, + api_key, + upload_certificate, # type: ignore[arg-type] + ) if is_upload: has_upload = True - inputs['tail_frame'] = res_tail_frame + inputs["tail_frame"] = res_tail_frame if first_frame_url is not None and first_frame_url: - is_upload, res_first_frame_url, upload_certificate = check_and_upload_local( - model, first_frame_url, api_key, upload_certificate) + ( + is_upload, + res_first_frame_url, + upload_certificate, + ) = check_and_upload_local( + model, + first_frame_url, + api_key, + upload_certificate, # type: ignore[arg-type] + ) if is_upload: has_upload = True - inputs['first_frame_url'] = res_first_frame_url + inputs["first_frame_url"] = res_first_frame_url if last_frame_url is not None and last_frame_url: - is_upload, res_last_frame_url, upload_certificate = check_and_upload_local( - model, last_frame_url, api_key, upload_certificate) + ( + is_upload, + res_last_frame_url, + upload_certificate, + ) = check_and_upload_local( + model, + last_frame_url, + api_key, + upload_certificate, # type: ignore[arg-type] + ) if is_upload: has_upload = True - inputs['last_frame_url'] = res_last_frame_url + inputs["last_frame_url"] = res_last_frame_url - if (reference_video_urls is not None - and reference_video_urls and len(reference_video_urls) > 0): + if ( + reference_video_urls is not None + and 
reference_video_urls + and len(reference_video_urls) > 0 + ): new_videos = [] for video in reference_video_urls: - is_upload, new_video, upload_certificate = check_and_upload_local( - model, video, api_key, upload_certificate) + ( + is_upload, + new_video, + upload_certificate, + ) = check_and_upload_local( + model, + video, + api_key, + upload_certificate, # type: ignore[arg-type] + ) if is_upload: has_upload = True new_videos.append(new_video) @@ -189,38 +262,41 @@ def _get_input(cls, if extra_input is not None and extra_input: inputs = {**inputs, **extra_input} if has_upload: - headers = kwargs.pop('headers', {}) - headers['X-DashScope-OssResourceResolve'] = 'enable' - kwargs['headers'] = headers + headers = kwargs.pop("headers", {}) + headers["X-DashScope-OssResourceResolve"] = "enable" + kwargs["headers"] = headers if task is None: task = VideoSynthesis.task - if model is not None and model and 'kf2v' in model: - task = 'image2video' + if model is not None and model and "kf2v" in model: + task = "image2video" return inputs, kwargs, task @classmethod - def async_call(cls, - model: str, - prompt: Any = None, - img_url: str = None, - audio_url: str = None, - reference_video_urls: List[str] = None, - reference_video_description: List[str] = None, - # """@deprecated, use prompt_extend in parameters """ - extend_prompt: bool = True, - negative_prompt: str = None, - template: str = None, - api_key: str = None, - extra_input: Dict = None, - workspace: str = None, - task: str = None, - head_frame: str = None, - tail_frame: str = None, - first_frame_url: str = None, - last_frame_url: str = None, - **kwargs) -> VideoSynthesisResponse: + # type: ignore[override] + def async_call( # pylint: disable=arguments-renamed # type: ignore[override] # noqa: E501 + cls, + model: str, + prompt: Any = None, + img_url: str = None, + audio_url: str = None, + reference_video_urls: List[str] = None, + reference_video_description: List[str] = None, + # """@deprecated, use prompt_extend in 
parameters """ + extend_prompt: bool = True, + negative_prompt: str = None, + template: str = None, + api_key: str = None, + extra_input: Dict = None, + workspace: str = None, + task: str = None, + head_frame: str = None, + tail_frame: str = None, + first_frame_url: str = None, + last_frame_url: str = None, + **kwargs, + ) -> VideoSynthesisResponse: """Create a video synthesis task, and return task information. Args: @@ -229,10 +305,10 @@ def async_call(cls, extend_prompt (bool): @deprecated, use prompt_extend in parameters negative_prompt (str): The negative prompt is the opposite of the prompt meaning. template (str): LoRa input, such as gufeng, katong, etc. - img_url (str): The input image url, Generate the URL of the image referenced by the video. + img_url (str): The input image url, Generate the URL of the image referenced by the video. # pylint: disable=line-too-long audio_url (str): The input audio url. - reference_video_urls (List[str]): list of character reference video file urls uploaded by the user - reference_video_description (List[str]): For the description information of the picture and sound of the reference video, corresponding to ref video, it needs to be in the order of the url. If the quantity is different, an error will be reported + reference_video_urls (List[str]): list of character reference video file urls uploaded by the user # pylint: disable=line-too-long + reference_video_description (List[str]): For the description information of the picture and sound of the reference video, corresponding to ref video, it needs to be in the order of the url. If the quantity is different, an error will be reported # pylint: disable=line-too-long api_key (str, optional): The api api_key. Defaults to None. workspace (str): The dashscope workspace id. extra_input (Dict): The extra input parameters. @@ -241,8 +317,11 @@ def async_call(cls, last_frame_url (str): The URL of the last frame image for generating the video. 
**kwargs: size(str, `optional`): The output video size(width*height). - duration(int, optional): The duration. Duration of video generation. The default value is 5, in seconds. - seed(int, optional): The seed. The random seed for video generation. The default value is 5. + duration( + int, + optional + ): The duration. Duration of video generation. The default value is 5, in seconds. # noqa: E501 + seed(int, optional): The seed. The random seed for video generation. The default value is 5. # noqa: E501 # pylint: disable=line-too-long Raises: InputRequired: The prompt cannot be empty. @@ -254,10 +333,25 @@ def async_call(cls, task_group, function = _get_task_group_and_task(__name__) inputs, kwargs, task = cls._get_input( - model, prompt, img_url, audio_url, reference_video_urls, reference_video_description, - extend_prompt, negative_prompt, template, api_key, - extra_input, task, function, head_frame, tail_frame, - first_frame_url, last_frame_url, **kwargs) + model, + prompt, + img_url, + audio_url, + reference_video_urls, + reference_video_description, + extend_prompt, + negative_prompt, + template, + api_key, + extra_input, + task, + function, + head_frame, + tail_frame, + first_frame_url, + last_frame_url, + **kwargs, + ) response = super().async_call( model=model, @@ -267,14 +361,17 @@ def async_call(cls, api_key=api_key, input=inputs, workspace=workspace, - **kwargs) + **kwargs, + ) return VideoSynthesisResponse.from_api_response(response) @classmethod - def fetch(cls, - task: Union[str, VideoSynthesisResponse], - api_key: str = None, - workspace: str = None) -> VideoSynthesisResponse: + def fetch( # type: ignore[override] + cls, + task: Union[str, VideoSynthesisResponse], + api_key: str = None, + workspace: str = None, + ) -> VideoSynthesisResponse: """Fetch video synthesis task status or result. 
Args: @@ -290,10 +387,12 @@ def fetch(cls, return VideoSynthesisResponse.from_api_response(response) @classmethod - def wait(cls, - task: Union[str, VideoSynthesisResponse], - api_key: str = None, - workspace: str = None) -> VideoSynthesisResponse: + def wait( # type: ignore[override] + cls, + task: Union[str, VideoSynthesisResponse], + api_key: str = None, + workspace: str = None, + ) -> VideoSynthesisResponse: """Wait for video synthesis task to complete, and return the result. Args: @@ -309,10 +408,12 @@ def wait(cls, return VideoSynthesisResponse.from_api_response(response) @classmethod - def cancel(cls, - task: Union[str, VideoSynthesisResponse], - api_key: str = None, - workspace: str = None) -> DashScopeAPIResponse: + def cancel( # type: ignore[override] + cls, + task: Union[str, VideoSynthesisResponse], + api_key: str = None, + workspace: str = None, + ) -> DashScopeAPIResponse: """Cancel video synthesis task. Only tasks whose status is PENDING can be canceled. @@ -328,18 +429,20 @@ def cancel(cls, return super().cancel(task, api_key, workspace=workspace) @classmethod - def list(cls, - start_time: str = None, - end_time: str = None, - model_name: str = None, - api_key_id: str = None, - region: str = None, - status: str = None, - page_no: int = 1, - page_size: int = 10, - api_key: str = None, - workspace: str = None, - **kwargs) -> DashScopeAPIResponse: + def list( + cls, + start_time: str = None, + end_time: str = None, + model_name: str = None, + api_key_id: str = None, + region: str = None, + status: str = None, + page_no: int = 1, + page_size: int = 10, + api_key: str = None, + workspace: str = None, + **kwargs, + ) -> DashScopeAPIResponse: """List async tasks. Args: @@ -361,40 +464,47 @@ def list(cls, Returns: DashScopeAPIResponse: The response data. 
""" - return super().list(start_time=start_time, - end_time=end_time, - model_name=model_name, - api_key_id=api_key_id, - region=region, - status=status, - page_no=page_no, - page_size=page_size, - api_key=api_key, - workspace=workspace, - **kwargs) + return super().list( + start_time=start_time, + end_time=end_time, + model_name=model_name, + api_key_id=api_key_id, + region=region, + status=status, + page_no=page_no, + page_size=page_size, + api_key=api_key, + workspace=workspace, + **kwargs, + ) + class AioVideoSynthesis(BaseAsyncAioApi): + # type: ignore[override] @classmethod - async def call(cls, - model: str, - prompt: Any = None, - img_url: str = None, - audio_url: str = None, - reference_video_urls: List[str] = None, - reference_video_description: List[str] = None, - # """@deprecated, use prompt_extend in parameters """ - extend_prompt: bool = True, - negative_prompt: str = None, - template: str = None, - api_key: str = None, - extra_input: Dict = None, - workspace: str = None, - task: str = None, - head_frame: str = None, - tail_frame: str = None, - first_frame_url: str = None, - last_frame_url: str = None, - **kwargs) -> VideoSynthesisResponse: + async def call( # type: ignore[override] # pylint: disable=arguments-renamed # noqa: E501 + # type: ignore[override] + cls, + model: str, + prompt: Any = None, + img_url: str = None, + audio_url: str = None, + reference_video_urls: List[str] = None, + reference_video_description: List[str] = None, + # """@deprecated, use prompt_extend in parameters """ + extend_prompt: bool = True, + negative_prompt: str = None, + template: str = None, + api_key: str = None, + extra_input: Dict = None, + workspace: str = None, + task: str = None, + head_frame: str = None, + tail_frame: str = None, + first_frame_url: str = None, + last_frame_url: str = None, + **kwargs, + ) -> VideoSynthesisResponse: """Call video synthesis service and get result. 
Args: @@ -403,10 +513,10 @@ async def call(cls, extend_prompt (bool): @deprecated, use prompt_extend in parameters negative_prompt (str): The negative prompt is the opposite of the prompt meaning. template (str): LoRa input, such as gufeng, katong, etc. - img_url (str): The input image url, Generate the URL of the image referenced by the video. + img_url (str): The input image url, Generate the URL of the image referenced by the video. # pylint: disable=line-too-long audio_url (str): The input audio url. - reference_video_urls (List[str]): list of character reference video file urls uploaded by the user - reference_video_description (List[str]): For the description information of the picture and sound of the reference video, corresponding to ref video, it needs to be in the order of the url. If the quantity is different, an error will be reported + reference_video_urls (List[str]): list of character reference video file urls uploaded by the user # pylint: disable=line-too-long + reference_video_description (List[str]): For the description information of the picture and sound of the reference video, corresponding to ref video, it needs to be in the order of the url. If the quantity is different, an error will be reported # pylint: disable=line-too-long api_key (str, optional): The api api_key. Defaults to None. workspace (str): The dashscope workspace id. extra_input (Dict): The extra input parameters. @@ -415,8 +525,11 @@ async def call(cls, last_frame_url (str): The URL of the last frame image for generating the video. **kwargs: size(str, `optional`): The output video size(width*height). - duration(int, optional): The duration. Duration of video generation. The default value is 5, in seconds. - seed(int, optional): The seed. The random seed for video generation. The default value is 5. + duration( + int, + optional + ): The duration. Duration of video generation. The default value is 5, in seconds. # noqa: E501 + seed(int, optional): The seed. 
The random seed for video generation. The default value is 5. # noqa: E501 # pylint: disable=line-too-long Raises: InputRequired: The prompt cannot be empty. @@ -425,35 +538,65 @@ async def call(cls, VideoSynthesisResponse: The video synthesis result. """ task_group, f = _get_task_group_and_task(__name__) + # pylint: disable=protected-access inputs, kwargs, task = VideoSynthesis._get_input( - model, prompt, img_url, audio_url, reference_video_urls, reference_video_description, - extend_prompt, negative_prompt, template, api_key, - extra_input, task, f, head_frame, tail_frame, - first_frame_url, last_frame_url, **kwargs) - response = await super().call(model, inputs, task_group, task, f, api_key, workspace, **kwargs) + model, + # pylint: disable=protected-access + prompt, + img_url, + audio_url, + reference_video_urls, + reference_video_description, + extend_prompt, + negative_prompt, + template, + api_key, + extra_input, + task, + f, + head_frame, + tail_frame, + first_frame_url, + last_frame_url, + **kwargs, + ) + response = await super().call( + model, + inputs, + task_group, + task, + f, + api_key, + workspace, + **kwargs, + ) return VideoSynthesisResponse.from_api_response(response) + # type: ignore[override] + @classmethod - async def async_call(cls, - model: str, - prompt: Any = None, - img_url: str = None, - audio_url: str = None, - reference_video_urls: List[str] = None, - reference_video_description: List[str] = None, - # """@deprecated, use prompt_extend in parameters """ - extend_prompt: bool = True, - negative_prompt: str = None, - template: str = None, - api_key: str = None, - extra_input: Dict = None, - workspace: str = None, - task: str = None, - head_frame: str = None, - tail_frame: str = None, - first_frame_url: str = None, - last_frame_url: str = None, - **kwargs) -> VideoSynthesisResponse: + async def async_call( # type: ignore[override] # pylint: disable=arguments-renamed # noqa: E501 + cls, + model: str, + prompt: Any = None, + img_url: str = 
None, + audio_url: str = None, + reference_video_urls: List[str] = None, + reference_video_description: List[str] = None, + # """@deprecated, use prompt_extend in parameters """ + extend_prompt: bool = True, + negative_prompt: str = None, + template: str = None, + api_key: str = None, + extra_input: Dict = None, + workspace: str = None, + task: str = None, + head_frame: str = None, + tail_frame: str = None, + first_frame_url: str = None, + last_frame_url: str = None, + **kwargs, + ) -> VideoSynthesisResponse: """Create a video synthesis task, and return task information. Args: @@ -462,10 +605,10 @@ async def async_call(cls, extend_prompt (bool): @deprecated, use prompt_extend in parameters negative_prompt (str): The negative prompt is the opposite of the prompt meaning. template (str): LoRa input, such as gufeng, katong, etc. - img_url (str): The input image url, Generate the URL of the image referenced by the video. + img_url (str): The input image url, Generate the URL of the image referenced by the video. # pylint: disable=line-too-long audio_url (str): The input audio url. - reference_video_urls (List[str]): list of character reference video file urls uploaded by the user - reference_video_description (List[str]): For the description information of the picture and sound of the reference video, corresponding to ref video, it needs to be in the order of the url. If the quantity is different, an error will be reported + reference_video_urls (List[str]): list of character reference video file urls uploaded by the user # pylint: disable=line-too-long + reference_video_description (List[str]): For the description information of the picture and sound of the reference video, corresponding to ref video, it needs to be in the order of the url. If the quantity is different, an error will be reported # pylint: disable=line-too-long api_key (str, optional): The api api_key. Defaults to None. workspace (str): The dashscope workspace id. 
extra_input (Dict): The extra input parameters. @@ -474,8 +617,11 @@ async def async_call(cls, last_frame_url (str): The URL of the last frame image for generating the video. **kwargs: size(str, `optional`): The output video size(width*height). - duration(int, optional): The duration. Duration of video generation. The default value is 5, in seconds. - seed(int, optional): The seed. The random seed for video generation. The default value is 5. + duration( + int, + optional + ): The duration. Duration of video generation. The default value is 5, in seconds. # noqa: E501 + seed(int, optional): The seed. The random seed for video generation. The default value is 5. # noqa: E501 # pylint: disable=line-too-long Raises: InputRequired: The prompt cannot be empty. @@ -486,11 +632,28 @@ async def async_call(cls, """ task_group, function = _get_task_group_and_task(__name__) + # pylint: disable=protected-access inputs, kwargs, task = VideoSynthesis._get_input( - model, prompt, img_url, audio_url, reference_video_urls, reference_video_description, - extend_prompt, negative_prompt, template, api_key, - extra_input, task, function, head_frame, tail_frame, - first_frame_url, last_frame_url, **kwargs) + model, + # pylint: disable=protected-access + prompt, + img_url, + audio_url, + reference_video_urls, + reference_video_description, + extend_prompt, + negative_prompt, + template, + api_key, + extra_input, + task, + function, + head_frame, + tail_frame, + first_frame_url, + last_frame_url, + **kwargs, + ) response = await super().async_call( model=model, @@ -500,15 +663,18 @@ async def async_call(cls, api_key=api_key, input=inputs, workspace=workspace, - **kwargs) + **kwargs, + ) return VideoSynthesisResponse.from_api_response(response) @classmethod - async def fetch(cls, - task: Union[str, VideoSynthesisResponse], - api_key: str = None, - workspace: str = None, - **kwargs) -> VideoSynthesisResponse: + async def fetch( + cls, + task: Union[str, VideoSynthesisResponse], # type: 
ignore[override] + api_key: str = None, + workspace: str = None, + **kwargs, + ) -> VideoSynthesisResponse: """Fetch video synthesis task status or result. Args: @@ -520,15 +686,21 @@ async def fetch(cls, Returns: VideoSynthesisResponse: The task status or result. """ - response = await super().fetch(task, api_key=api_key, workspace=workspace) + response = await super().fetch( + task, + api_key=api_key, + workspace=workspace, + ) return VideoSynthesisResponse.from_api_response(response) @classmethod - async def wait(cls, - task: Union[str, VideoSynthesisResponse], - api_key: str = None, - workspace: str = None, - **kwargs) -> VideoSynthesisResponse: + async def wait( + cls, + task: Union[str, VideoSynthesisResponse], # type: ignore[override] + api_key: str = None, + workspace: str = None, + **kwargs, + ) -> VideoSynthesisResponse: """Wait for video synthesis task to complete, and return the result. Args: @@ -544,11 +716,13 @@ async def wait(cls, return VideoSynthesisResponse.from_api_response(response) @classmethod - async def cancel(cls, - task: Union[str, VideoSynthesisResponse], - api_key: str = None, - workspace: str = None, - **kwargs) -> DashScopeAPIResponse: + async def cancel( + cls, + task: Union[str, VideoSynthesisResponse], # type: ignore[override] + api_key: str = None, + workspace: str = None, + **kwargs, + ) -> DashScopeAPIResponse: """Cancel video synthesis task. Only tasks whose status is PENDING can be canceled. 
@@ -564,18 +738,20 @@ async def cancel(cls, return await super().cancel(task, api_key, workspace=workspace) @classmethod - async def list(cls, - start_time: str = None, - end_time: str = None, - model_name: str = None, - api_key_id: str = None, - region: str = None, - status: str = None, - page_no: int = 1, - page_size: int = 10, - api_key: str = None, - workspace: str = None, - **kwargs) -> DashScopeAPIResponse: + async def list( + cls, + start_time: str = None, + end_time: str = None, + model_name: str = None, + api_key_id: str = None, + region: str = None, + status: str = None, + page_no: int = 1, + page_size: int = 10, + api_key: str = None, + workspace: str = None, + **kwargs, + ) -> DashScopeAPIResponse: """List async tasks. Args: @@ -597,14 +773,16 @@ async def list(cls, Returns: DashScopeAPIResponse: The response data. """ - return await super().list(start_time=start_time, - end_time=end_time, - model_name=model_name, - api_key_id=api_key_id, - region=region, - status=status, - page_no=page_no, - page_size=page_size, - api_key=api_key, - workspace=workspace, - **kwargs) \ No newline at end of file + return await super().list( + start_time=start_time, + end_time=end_time, + model_name=model_name, + api_key_id=api_key_id, + region=region, + status=status, + page_no=page_no, + page_size=page_size, + api_key=api_key, + workspace=workspace, + **kwargs, + ) diff --git a/dashscope/api_entities/aiohttp_request.py b/dashscope/api_entities/aiohttp_request.py index 1f0bcd0..75b3965 100644 --- a/dashscope/api_entities/aiohttp_request.py +++ b/dashscope/api_entities/aiohttp_request.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
import json @@ -7,24 +8,29 @@ from dashscope.api_entities.base_request import AioBaseRequest from dashscope.api_entities.dashscope_response import DashScopeAPIResponse -from dashscope.common.constants import (DEFAULT_REQUEST_TIMEOUT_SECONDS, - SSE_CONTENT_TYPE, HTTPMethod) +from dashscope.common.constants import ( + DEFAULT_REQUEST_TIMEOUT_SECONDS, + SSE_CONTENT_TYPE, + HTTPMethod, +) from dashscope.common.error import UnsupportedHTTPMethod from dashscope.common.logging import logger from dashscope.common.utils import async_to_sync class AioHttpRequest(AioBaseRequest): - def __init__(self, - url: str, - api_key: str, - http_method: str, - stream: bool = True, - async_request: bool = False, - query: bool = False, - timeout: int = DEFAULT_REQUEST_TIMEOUT_SECONDS, - task_id: str = None, - user_agent: str = '') -> None: + def __init__( + self, + url: str, + api_key: str, + http_method: str, + stream: bool = True, + async_request: bool = False, + query: bool = False, + timeout: int = DEFAULT_REQUEST_TIMEOUT_SECONDS, + task_id: str = None, + user_agent: str = "", + ) -> None: """HttpSSERequest, processing http server sent event stream. 
Args: @@ -42,33 +48,33 @@ def __init__(self, self.url = url self.async_request = async_request self.headers = { - 'Accept': 'application/json', - 'Authorization': 'Bearer %s' % api_key, - 'Cache-Control': 'no-cache', + "Accept": "application/json", + "Authorization": f"Bearer {api_key}", + "Cache-Control": "no-cache", **self.headers, } self.query = query if self.async_request and self.query is False: self.headers = { - 'X-DashScope-Async': 'enable', + "X-DashScope-Async": "enable", **self.headers, } self.method = http_method if self.method == HTTPMethod.POST: - self.headers['Content-Type'] = 'application/json' + self.headers["Content-Type"] = "application/json" self.stream = stream if self.stream: - self.headers['Accept'] = SSE_CONTENT_TYPE - self.headers['X-Accel-Buffering'] = 'no' - self.headers['X-DashScope-SSE'] = 'enable' + self.headers["Accept"] = SSE_CONTENT_TYPE + self.headers["X-Accel-Buffering"] = "no" + self.headers["X-DashScope-SSE"] = "enable" if self.query: - self.url = self.url.replace('api', 'api-task') - self.url += '%s' % task_id + self.url = self.url.replace("api", "api-task") + self.url += f"{task_id}" if timeout is None: self.timeout = DEFAULT_REQUEST_TIMEOUT_SECONDS else: - self.timeout = timeout + self.timeout = timeout # type: ignore[has-type] def add_header(self, key, value): self.headers[key] = value @@ -106,57 +112,72 @@ async def _handle_stream(self, response): status_code = HTTPStatus.BAD_REQUEST async for line in response.content: if line: - line = line.decode('utf8') - line = line.rstrip('\n').rstrip('\r') - if line.startswith('event:error'): + line = line.decode("utf8") + line = line.rstrip("\n").rstrip("\r") + if line.startswith("event:error"): is_error = True - elif line.startswith('status:'): - status_code = line[len('status:'):] + elif line.startswith("status:"): + status_code = line[len("status:") :] status_code = int(status_code.strip()) - elif line.startswith('data:'): - line = line[len('data:'):] + elif 
line.startswith("data:"): + line = line[len("data:") :] yield (is_error, status_code, line) if is_error: break else: continue # ignore heartbeat... - async def _handle_response(self, response: aiohttp.ClientResponse): - request_id = '' - if (response.status == HTTPStatus.OK and self.stream - and SSE_CONTENT_TYPE in response.content_type): + # pylint: disable=too-many-statements + async def _handle_response( # pylint: disable=too-many-branches + self, + response: aiohttp.ClientResponse, + ): + request_id = "" + if ( + response.status == HTTPStatus.OK + and self.stream + and SSE_CONTENT_TYPE in response.content_type + ): async for is_error, status_code, data in self._handle_stream( - response): + response, + ): try: output = None usage = None msg = json.loads(data) if not is_error: - if 'output' in msg: - output = msg['output'] - if 'usage' in msg: - usage = msg['usage'] - if 'request_id' in msg: - request_id = msg['request_id'] + if "output" in msg: + output = msg["output"] + if "usage" in msg: + usage = msg["usage"] + if "request_id" in msg: + request_id = msg["request_id"] except json.JSONDecodeError: yield DashScopeAPIResponse( request_id=request_id, status_code=HTTPStatus.INTERNAL_SERVER_ERROR, - code='Unknown', - message=data) + code="Unknown", + message=data, + ) continue if is_error: - yield DashScopeAPIResponse(request_id=request_id, - status_code=status_code, - code=msg['code'], - message=msg['message']) + yield DashScopeAPIResponse( + request_id=request_id, + status_code=status_code, + code=msg["code"], + message=msg["message"], + ) else: - yield DashScopeAPIResponse(request_id=request_id, - status_code=HTTPStatus.OK, - output=output, - usage=usage) - elif (response.status == HTTPStatus.OK - and 'multipart' in response.content_type): + yield DashScopeAPIResponse( + request_id=request_id, + status_code=HTTPStatus.OK, + output=output, + usage=usage, + ) + elif ( + response.status == HTTPStatus.OK + and "multipart" in response.content_type + ): reader = 
aiohttp.MultipartReader.from_response(response) output = {} while True: @@ -164,76 +185,99 @@ async def _handle_response(self, response: aiohttp.ClientResponse): if part is None: break output[part.name] = await part.read() - if 'request_id' in output: - request_id = output['request_id'] - yield DashScopeAPIResponse(request_id=request_id, - status_code=HTTPStatus.OK, - output=output) + # pylint: disable=consider-using-get + if "request_id" in output: + request_id = output["request_id"] + yield DashScopeAPIResponse( + request_id=request_id, + status_code=HTTPStatus.OK, + output=output, + ) elif response.status == HTTPStatus.OK: json_content = await response.json() output = None usage = None - if 'output' in json_content and json_content['output'] is not None: - output = json_content['output'] - if 'usage' in json_content: - usage = json_content['usage'] - if 'request_id' in json_content: - request_id = json_content['request_id'] - yield DashScopeAPIResponse(request_id=request_id, - status_code=HTTPStatus.OK, - output=output, - usage=usage) + if "output" in json_content and json_content["output"] is not None: + output = json_content["output"] + if "usage" in json_content: + usage = json_content["usage"] + if "request_id" in json_content: + request_id = json_content["request_id"] + yield DashScopeAPIResponse( + request_id=request_id, + status_code=HTTPStatus.OK, + output=output, + usage=usage, + ) else: - if 'application/json' in response.content_type: + if "application/json" in response.content_type: error = await response.json() - if 'request_id' in error: - request_id = error['request_id'] - if 'message' not in error: - message = '' - logger.error('Request: %s failed, status: %s' % - (self.url, response.status)) + if "request_id" in error: + request_id = error["request_id"] + if "message" not in error: + message = "" + logger.error( + "Request: %s failed, status: %s", + self.url, + response.status, + ) else: - message = error['message'] + message = error["message"] 
logger.error( - 'Request: %s failed, status: %s, message: %s' % - (self.url, response.status, error['message'])) - yield DashScopeAPIResponse(request_id=request_id, - status_code=response.status, - code=error['code'], - message=message) + "Request: %s failed, status: %s, message: %s", + self.url, + response.status, + error["message"], + ) + yield DashScopeAPIResponse( + request_id=request_id, + status_code=response.status, + code=error["code"], + message=message, + ) else: msg = await response.read() - yield DashScopeAPIResponse(request_id=request_id, - status_code=response.status, - code='Unknown', - message=msg.decode('utf-8')) + yield DashScopeAPIResponse( + request_id=request_id, + status_code=response.status, + code="Unknown", + message=msg.decode("utf-8"), + ) async def _handle_request(self): try: async with aiohttp.ClientSession( - timeout=aiohttp.ClientTimeout(total=self.timeout), - headers=self.headers) as session: - logger.debug('Starting request: %s' % self.url) + timeout=aiohttp.ClientTimeout(total=self.timeout), + headers=self.headers, + ) as session: + logger.debug("Starting request: %s", self.url) if self.method == HTTPMethod.POST: is_form, obj = self.data.get_aiohttp_payload() if is_form: headers = {**self.headers, **obj.headers} - response = await session.post(url=self.url, - data=obj, - headers=headers) + response = await session.post( + url=self.url, + data=obj, + headers=headers, + ) else: - response = await session.request('POST', - url=self.url, - json=obj, - headers=self.headers) + response = await session.request( + "POST", + url=self.url, + json=obj, + headers=self.headers, + ) elif self.method == HTTPMethod.GET: - response = await session.get(url=self.url, - params=self.data.parameters, - headers=self.headers) + response = await session.get( + url=self.url, + params=self.data.parameters, + headers=self.headers, + ) else: - raise UnsupportedHTTPMethod('Unsupported http method: %s' % - self.method) - logger.debug('Response returned: %s' % 
self.url) + raise UnsupportedHTTPMethod( + f"Unsupported http method: {self.method}", + ) + logger.debug("Response returned: %s", self.url) async with response: async for rsp in self._handle_response(response): yield rsp diff --git a/dashscope/api_entities/api_request_data.py b/dashscope/api_entities/api_request_data.py index 173431a..d6c658a 100644 --- a/dashscope/api_entities/api_request_data.py +++ b/dashscope/api_entities/api_request_data.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. import json @@ -9,14 +10,14 @@ from dashscope.io.input_output import InputResolver -class ApiRequestData(): +class ApiRequestData: def __init__( self, model, task_group, task, function, - input, + input_data, form, is_binary_input, api_protocol, @@ -25,7 +26,7 @@ def __init__( self.task = task self.task_group = task_group self.function = function - self._input = input + self._input = input_data self._input_type = {} self._input_generators = {} self.parameters = {} @@ -37,8 +38,10 @@ def __init__( if api_protocol in [ApiProtocol.HTTP, ApiProtocol.HTTPS]: self._input_resolver = InputResolver(input_instance=self._input) else: - self._input_resolver = InputResolver(input_instance=self._input, - is_encode_binary=False) + self._input_resolver = InputResolver( + input_instance=self._input, + is_encode_binary=False, + ) def add_parameters(self, **params): for key, value in params.items(): @@ -56,10 +59,14 @@ def to_request_object(self) -> str: o = { k: v for k, v in self.__dict__.items() - if not (k.startswith('_') or k.startswith('task') - or k.startswith('function') or v is None) + if not ( + k.startswith("_") + or k.startswith("task") + or k.startswith("function") + or v is None + ) } - return o + return o # type: ignore[return-value] def get_aiohttp_payload(self): """Get http payload. 
@@ -74,11 +81,12 @@ def get_aiohttp_payload(self): form = aiohttp.FormData() for key, value in self._form.items(): form.add_field(key, value) - form.add_field('model', data['model']) - if 'input' in data: - form.add_field('input', json.dumps(data['input'])) - form.add_field('parameters', json.dumps(data['parameters'])) + form.add_field("model", data["model"]) + if "input" in data: + form.add_field("input", json.dumps(data["input"])) + form.add_field("parameters", json.dumps(data["parameters"])) return True, form() + # pylint: disable=unreachable,pointless-string-statement """ mp_writer = aiohttp.MultipartWriter('mixed') mp_writer.append('model=%s'%self.model) @@ -121,7 +129,7 @@ def get_websocket_start_data(self): data = { k: v for k, v in self.__dict__.items() - if not (k.startswith('_') or v is None) + if not (k.startswith("_") or v is None) } return data @@ -133,11 +141,11 @@ def _to_json_only_data(self) -> str: o = { k: v for k, v in self.__dict__.items() - if not (k.startswith('_') or k.startswith('param')) + if not (k.startswith("_") or k.startswith("param")) } return json.dumps(o, default=lambda o: o.__dict__) - def get_batch_binary_data(self) -> bytes: + def get_batch_binary_data(self) -> bytes: # type: ignore[return] """Get binary data. used in streaming mode none and out (input is not streaming), we send data in one package. In this case only has one field input. 
@@ -149,21 +157,22 @@ def get_batch_binary_data(self) -> bytes: return content def _only_parameters(self) -> str: - obj = {'model': self.model, 'parameters': self.parameters, 'input': {}} + obj = {"model": self.model, "parameters": self.parameters, "input": {}} if self.task is not None: - obj['task'] = self.task + obj["task"] = self.task if self.task_group is not None: - obj['task_group'] = self.task_group + obj["task_group"] = self.task_group if self.function is not None: - obj['function'] = self.function + obj["function"] = self.function if self.resources is not None: - obj['resources'] = self.resources - return obj + obj["resources"] = self.resources + return obj # type: ignore[return-value] def to_query_parameters(self) -> str: - query_string = '?' + # pylint: disable=not-an-iterable + query_string = "?" for key, value in self.parameters.items: - param = '%s/%s&' % (key, value) + param = f"{key}/{value}&" query_string += param query_string = query_string[0:-1] # remove last # - return urlencode(query_string) + return urlencode(query_string) # type: ignore[arg-type] diff --git a/dashscope/api_entities/api_request_factory.py b/dashscope/api_entities/api_request_factory.py index 9347a19..bc96fc0 100644 --- a/dashscope/api_entities/api_request_factory.py +++ b/dashscope/api_entities/api_request_factory.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
from urllib.parse import urlencode @@ -5,134 +6,174 @@ from dashscope.api_entities.api_request_data import ApiRequestData from dashscope.api_entities.http_request import HttpRequest from dashscope.api_entities.websocket_request import WebSocketRequest -from dashscope.common.constants import (REQUEST_TIMEOUT_KEYWORD, - SERVICE_API_PATH, ApiProtocol, - HTTPMethod) +from dashscope.common.constants import ( + REQUEST_TIMEOUT_KEYWORD, + SERVICE_API_PATH, + ApiProtocol, + HTTPMethod, +) from dashscope.common.error import InputDataRequired, UnsupportedApiProtocol from dashscope.common.logging import logger from dashscope.protocol.websocket import WebsocketStreamingMode from dashscope.api_entities.encryption import Encryption + def _get_protocol_params(kwargs): - api_protocol = kwargs.pop('api_protocol', ApiProtocol.HTTPS) - ws_stream_mode = kwargs.pop('ws_stream_mode', WebsocketStreamingMode.OUT) - is_binary_input = kwargs.pop('is_binary_input', False) - http_method = kwargs.pop('http_method', HTTPMethod.POST) - stream = kwargs.get('stream', False) + api_protocol = kwargs.pop("api_protocol", ApiProtocol.HTTPS) + ws_stream_mode = kwargs.pop("ws_stream_mode", WebsocketStreamingMode.OUT) + is_binary_input = kwargs.pop("is_binary_input", False) + http_method = kwargs.pop("http_method", HTTPMethod.POST) + stream = kwargs.get("stream", False) if not stream and ws_stream_mode == WebsocketStreamingMode.OUT: ws_stream_mode = WebsocketStreamingMode.NONE - async_request = kwargs.pop('async_request', False) - query = kwargs.pop('query', False) - headers = kwargs.pop('headers', None) + async_request = kwargs.pop("async_request", False) + query = kwargs.pop("query", False) + headers = kwargs.pop("headers", None) request_timeout = kwargs.pop(REQUEST_TIMEOUT_KEYWORD, None) - form = kwargs.pop('form', None) - resources = kwargs.pop('resources', None) - base_address = kwargs.pop('base_address', None) - flattened_output = kwargs.pop('flattened_output', False) - extra_url_parameters = 
kwargs.pop('extra_url_parameters', None) + form = kwargs.pop("form", None) + resources = kwargs.pop("resources", None) + base_address = kwargs.pop("base_address", None) + flattened_output = kwargs.pop("flattened_output", False) + extra_url_parameters = kwargs.pop("extra_url_parameters", None) # Extract user-agent from headers if present - user_agent = '' - if headers and 'user-agent' in headers: - user_agent = headers.pop('user-agent') - - return (api_protocol, ws_stream_mode, is_binary_input, http_method, stream, - async_request, query, headers, request_timeout, form, resources, - base_address, flattened_output, extra_url_parameters, user_agent) - - -def _build_api_request(model: str, - input: object, - task_group: str, - task: str, - function: str, - api_key: str, - is_service=True, - **kwargs): - (api_protocol, ws_stream_mode, is_binary_input, http_method, stream, - async_request, query, headers, request_timeout, form, resources, - base_address, flattened_output, extra_url_parameters, - user_agent) = _get_protocol_params(kwargs) - task_id = kwargs.pop('task_id', None) - enable_encryption = kwargs.pop('enable_encryption', False) + user_agent = "" + if headers and "user-agent" in headers: + user_agent = headers.pop("user-agent") + + return ( + api_protocol, + ws_stream_mode, + is_binary_input, + http_method, + stream, + async_request, + query, + headers, + request_timeout, + form, + resources, + base_address, + flattened_output, + extra_url_parameters, + user_agent, + ) + + +def _build_api_request( # pylint: disable=too-many-branches + model: str, + input: object, # pylint: disable=redefined-builtin + task_group: str, + task: str, + function: str, + api_key: str, + is_service=True, + **kwargs, +): + ( + api_protocol, + ws_stream_mode, + is_binary_input, + http_method, + stream, + async_request, + query, + headers, + request_timeout, + form, + resources, + base_address, + flattened_output, + extra_url_parameters, + user_agent, + ) = _get_protocol_params(kwargs) + 
task_id = kwargs.pop("task_id", None) + enable_encryption = kwargs.pop("enable_encryption", False) encryption = None if api_protocol in [ApiProtocol.HTTP, ApiProtocol.HTTPS]: if base_address is None: base_address = dashscope.base_http_api_url - if not base_address.endswith('/'): - http_url = base_address + '/' + if not base_address.endswith("/"): + http_url = base_address + "/" else: http_url = base_address if is_service: - http_url = http_url + SERVICE_API_PATH + '/' + http_url = http_url + SERVICE_API_PATH + "/" if task_group: - http_url += '%s/' % task_group + http_url += f"{task_group}/" if task: - http_url += '%s/' % task + http_url += f"{task}/" if function: http_url += function if extra_url_parameters is not None and extra_url_parameters: - http_url += '?' + urlencode(extra_url_parameters) + http_url += "?" + urlencode(extra_url_parameters) if enable_encryption is True: encryption = Encryption() encryption.initialize() if encryption.is_valid(): - logger.debug('encryption enabled') - - request = HttpRequest(url=http_url, - api_key=api_key, - http_method=http_method, - stream=stream, - async_request=async_request, - query=query, - timeout=request_timeout, - task_id=task_id, - flattened_output=flattened_output, - encryption=encryption, - user_agent=user_agent) + logger.debug("encryption enabled") + + request = HttpRequest( + url=http_url, + api_key=api_key, + http_method=http_method, + stream=stream, + async_request=async_request, + query=query, + timeout=request_timeout, + task_id=task_id, + flattened_output=flattened_output, + encryption=encryption, + user_agent=user_agent, + ) elif api_protocol == ApiProtocol.WEBSOCKET: if base_address is not None: websocket_url = base_address else: websocket_url = dashscope.base_websocket_api_url - pre_task_id = kwargs.pop('pre_task_id', None) - request = WebSocketRequest(url=websocket_url, - api_key=api_key, - stream=stream, - ws_stream_mode=ws_stream_mode, - is_binary_input=is_binary_input, - timeout=request_timeout, - 
flattened_output=flattened_output, - pre_task_id=pre_task_id, - user_agent=user_agent) + pre_task_id = kwargs.pop("pre_task_id", None) + request = WebSocketRequest( + url=websocket_url, + api_key=api_key, + stream=stream, + ws_stream_mode=ws_stream_mode, + is_binary_input=is_binary_input, + timeout=request_timeout, + flattened_output=flattened_output, + pre_task_id=pre_task_id, + user_agent=user_agent, + ) else: raise UnsupportedApiProtocol( - 'Unsupported protocol: %s, support [http, https, websocket]' % - api_protocol) + f"Unsupported protocol: {api_protocol}, support [http, https, " + "websocket]", + ) if headers is not None: request.add_headers(headers=headers) if input is None and form is None: - raise InputDataRequired('There is no input data and form data') + raise InputDataRequired("There is no input data and form data") if encryption and encryption.is_valid(): input = encryption.encrypt(input) - request_data = ApiRequestData(model, - task_group=task_group, - task=task, - function=function, - input=input, - form=form, - is_binary_input=is_binary_input, - api_protocol=api_protocol) + request_data = ApiRequestData( + model, + task_group=task_group, + task=task, + function=function, + input_data=input, + form=form, + is_binary_input=is_binary_input, + api_protocol=api_protocol, + ) request_data.add_resources(resources) request_data.add_parameters(**kwargs) request.data = request_data - return request \ No newline at end of file + return request diff --git a/dashscope/api_entities/base_request.py b/dashscope/api_entities/base_request.py index 05db0d9..0b86d7a 100644 --- a/dashscope/api_entities/base_request.py +++ b/dashscope/api_entities/base_request.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
import os @@ -9,7 +10,7 @@ class BaseRequest(ABC): - def __init__(self, user_agent: str = '') -> None: + def __init__(self, user_agent: str = "") -> None: try: platform_info = platform.platform() except Exception: @@ -20,23 +21,23 @@ def __init__(self, user_agent: str = '') -> None: except Exception: processor_info = "unknown" - ua = 'dashscope/%s; python/%s; platform/%s; processor/%s' % ( - __version__, - platform.python_version(), - platform_info, - processor_info, + ua = ( + f"dashscope/{__version__}; python/{platform.python_version()}; " + f"platform/{platform_info}; processor/{processor_info}" ) # Append user_agent if provided and not empty if user_agent: - ua += '; ' + user_agent + ua += "; " + user_agent - self.headers = {'user-agent': ua} + self.headers = {"user-agent": ua} disable_data_inspection = os.environ.get( - DASHSCOPE_DISABLE_DATA_INSPECTION_ENV, 'true') + DASHSCOPE_DISABLE_DATA_INSPECTION_ENV, + "true", + ) - if (disable_data_inspection.lower() == 'false'): - self.headers['X-DashScope-DataInspection'] = 'enable' + if disable_data_inspection.lower() == "false": + self.headers["X-DashScope-DataInspection"] = "enable" @abstractmethod def call(self): diff --git a/dashscope/api_entities/chat_completion_types.py b/dashscope/api_entities/chat_completion_types.py index 45c3696..5d8ac2f 100644 --- a/dashscope/api_entities/chat_completion_types.py +++ b/dashscope/api_entities/chat_completion_types.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
# adapter from openai sdk @@ -17,7 +18,8 @@ class CompletionUsage(BaseObjectMixin): total_tokens: int """Total number of tokens used in the request (prompt + completion).""" - def __init__(self, **kwargs): + + def __init__(self, **kwargs): # pylint: disable=useless-parent-delegation super().__init__(**kwargs) @@ -27,21 +29,22 @@ class TopLogprob(BaseObjectMixin): """The token.""" bytes: Optional[List[int]] = None - """A list of integers representing the UTF-8 bytes representation of the token. + """A list of integers representing the UTF-8 bytes representation of the token. # noqa: E501 Useful in instances where characters are represented by multiple tokens and their byte representations must be combined to generate the correct text representation. Can be `null` if there is no bytes representation for the token. """ - logprob: float - """The log probability of this token, if it is within the top 20 most likely + logprob: float # type: ignore[misc] + """The log probability of this token, if it is within the top 20 most likely # noqa: E501 tokens. Otherwise, the value `-9999.0` is used to signify that the token is very unlikely. """ - def __init__(self, **kwargs): + + def __init__(self, **kwargs): # pylint: disable=useless-parent-delegation super().__init__(**kwargs) @@ -51,33 +54,37 @@ class ChatCompletionTokenLogprob(BaseObjectMixin): """The token.""" bytes: Optional[List[int]] = None - """A list of integers representing the UTF-8 bytes representation of the token. + """A list of integers representing the UTF-8 bytes representation of the token. # noqa: E501 Useful in instances where characters are represented by multiple tokens and their byte representations must be combined to generate the correct text representation. Can be `null` if there is no bytes representation for the token. 
""" - logprob: float - """The log probability of this token, if it is within the top 20 most likely + logprob: float # type: ignore[misc] + """The log probability of this token, if it is within the top 20 most likely # noqa: E501 tokens. Otherwise, the value `-9999.0` is used to signify that the token is very unlikely. """ - top_logprobs: List[TopLogprob] + top_logprobs: List[TopLogprob] # type: ignore[misc] """List of the most likely tokens and their log probability, at this token position. - In rare cases, there may be fewer than the number of requested `top_logprobs` + In rare cases, there may be fewer than the number of requested `top_logprobs` # noqa: E501 returned. """ + def __init__(self, **kwargs): - if 'top_logprobs' in kwargs and kwargs[ - 'top_logprobs'] is not None and kwargs['top_logprobs']: + if ( + "top_logprobs" in kwargs + and kwargs["top_logprobs"] is not None + and kwargs["top_logprobs"] + ): top_logprobs = [] - for logprob in kwargs['top_logprobs']: + for logprob in kwargs["top_logprobs"]: top_logprobs.append(ChatCompletionTokenLogprob(**logprob)) self.top_logprobs = top_logprobs else: @@ -90,11 +97,15 @@ def __init__(self, **kwargs): class ChoiceLogprobs(BaseObjectMixin): content: Optional[List[ChatCompletionTokenLogprob]] = None """A list of message content tokens with log probability information.""" + def __init__(self, **kwargs): - if 'content' in kwargs and kwargs['content'] is not None and kwargs[ - 'content']: + if ( + "content" in kwargs + and kwargs["content"] is not None + and kwargs["content"] + ): logprobs = [] - for logprob in kwargs['content']: + for logprob in kwargs["content"]: logprobs.append(ChatCompletionTokenLogprob(**logprob)) self.content = logprobs else: @@ -115,7 +126,8 @@ class FunctionCall(BaseObjectMixin): name: str """The name of the function to call.""" - def __init__(self, **kwargs): + + def __init__(self, **kwargs): # pylint: disable=useless-parent-delegation super().__init__(**kwargs) @@ -131,7 +143,8 @@ class 
Function(BaseObjectMixin): name: str """The name of the function to call.""" - def __init__(self, **kwargs): + + def __init__(self, **kwargs): # pylint: disable=useless-parent-delegation super().__init__(**kwargs) @@ -143,12 +156,16 @@ class ChatCompletionMessageToolCall(BaseObjectMixin): function: Function """The function that the model called.""" - type: Literal['function'] + type: Literal["function"] """The type of the tool. Currently, only `function` is supported.""" + def __init__(self, **kwargs): - if 'function' in kwargs and kwargs['function'] is not None and kwargs[ - 'function']: - self.function = Function(**kwargs.pop('function', {})) + if ( + "function" in kwargs + and kwargs["function"] is not None + and kwargs["function"] + ): + self.function = Function(**kwargs.pop("function", {})) else: self.function = None @@ -160,28 +177,36 @@ class ChatCompletionMessage(BaseObjectMixin): content: Optional[str] = None """The contents of the message.""" - role: Literal['assistant'] + role: Literal["assistant"] # type: ignore[misc] """The role of the author of this message.""" function_call: Optional[FunctionCall] = None """Deprecated and replaced by `tool_calls`. - The name and arguments of a function that should be called, as generated by the + The name and arguments of a function that should be called, as generated by the # noqa: E501 model. 
""" tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None """The tool calls generated by the model, such as function calls.""" + def __init__(self, **kwargs): - if 'function_call' in kwargs and kwargs[ - 'function_call'] is not None and kwargs['function_call']: + if ( + "function_call" in kwargs + and kwargs["function_call"] is not None + and kwargs["function_call"] + ): self.function_call = FunctionCall( - **kwargs.pop('function_call', {})) - - if 'tool_calls' in kwargs and kwargs[ - 'tool_calls'] is not None and kwargs['tool_calls']: + **kwargs.pop("function_call", {}), + ) + + if ( + "tool_calls" in kwargs + and kwargs["tool_calls"] is not None + and kwargs["tool_calls"] + ): tool_calls = [] - for tool_call in kwargs['tool_calls']: + for tool_call in kwargs["tool_calls"]: tool_calls.append(ChatCompletionMessageToolCall(**tool_call)) self.tool_calls = tool_calls @@ -190,13 +215,18 @@ def __init__(self, **kwargs): @dataclass(init=False) class Choice(BaseObjectMixin): - finish_reason: Literal['stop', 'length', 'tool_calls', 'content_filter', - 'function_call'] + finish_reason: Literal[ + "stop", + "length", + "tool_calls", + "content_filter", + "function_call", + ] """The reason the model stopped generating tokens. - This will be `stop` if the model hit a natural stop point or a provided stop - sequence, `length` if the maximum number of tokens specified in the request was - reached, `content_filter` if content was omitted due to a flag from our content + This will be `stop` if the model hit a natural stop point or a provided stop # noqa: E501 + sequence, `length` if the maximum number of tokens specified in the request was # noqa: E501 + reached, `content_filter` if content was omitted due to a flag from our content # noqa: E501 filters, `tool_calls` if the model called a tool, or `function_call` (deprecated) if the model called a function. 
""" @@ -207,18 +237,25 @@ class Choice(BaseObjectMixin): logprobs: Optional[ChoiceLogprobs] = None """Log probability information for the choice.""" - message: ChatCompletionMessage + message: ChatCompletionMessage # type: ignore[misc] """A chat completion message generated by the model.""" + def __init__(self, **kwargs): - if 'message' in kwargs and kwargs['message'] is not None and kwargs[ - 'message']: - self.message = ChatCompletionMessage(**kwargs.pop('message', {})) + if ( + "message" in kwargs + and kwargs["message"] is not None + and kwargs["message"] + ): + self.message = ChatCompletionMessage(**kwargs.pop("message", {})) else: self.message = None - if 'logprobs' in kwargs and kwargs['logprobs'] is not None and kwargs[ - 'logprobs']: - self.logprobs = ChoiceLogprobs(**kwargs.pop('logprobs', {})) + if ( + "logprobs" in kwargs + and kwargs["logprobs"] is not None + and kwargs["logprobs"] + ): + self.logprobs = ChoiceLogprobs(**kwargs.pop("logprobs", {})) super().__init__(**kwargs) @@ -244,16 +281,16 @@ class ChatCompletion(BaseObjectMixin): """ created: int - """The Unix timestamp (in seconds) of when the chat completion was created.""" + """The Unix timestamp (in seconds) of when the chat completion was created.""" # noqa: E501 model: str """The model used for the chat completion.""" - object: Literal['chat.completion'] + object: Literal["chat.completion"] """The object type, which is always `chat.completion`.""" system_fingerprint: Optional[str] = None - """This fingerprint represents the backend configuration that the model runs with. + """This fingerprint represents the backend configuration that the model runs with. # noqa: E501 Can be used in conjunction with the `seed` request parameter to understand when backend changes have been made that might impact determinism. 
@@ -261,17 +298,24 @@ class ChatCompletion(BaseObjectMixin): usage: Optional[CompletionUsage] = None """Usage statistics for the completion request.""" + def __init__(self, **kwargs): - if 'usage' in kwargs and kwargs['usage'] is not None and kwargs[ - 'usage']: - self.usage = CompletionUsage(**kwargs.pop('usage', {})) + if ( + "usage" in kwargs + and kwargs["usage"] is not None + and kwargs["usage"] + ): + self.usage = CompletionUsage(**kwargs.pop("usage", {})) else: self.usage = None - if 'choices' in kwargs and kwargs['choices'] is not None and kwargs[ - 'choices']: + if ( + "choices" in kwargs + and kwargs["choices"] is not None + and kwargs["choices"] + ): choices = [] - for choice in kwargs.pop('choices', []): + for choice in kwargs.pop("choices", []): choices.append(Choice(**choice)) self.choices = choices else: @@ -291,12 +335,12 @@ class ChatCompletionChunk(BaseObjectMixin): """The request failed, this is the error message. """ id: str - """A unique identifier for the chat completion. Each chunk has the same ID.""" + """A unique identifier for the chat completion. Each chunk has the same ID.""" # noqa: E501 choices: List[Choice] """A list of chat completion choices. - Can contain more than one elements if `n` is greater than 1. Can also be empty + Can contain more than one elements if `n` is greater than 1. Can also be empty # noqa: E501 for the last chunk if you set `stream_options: {"include_usage": true}`. """ @@ -309,13 +353,13 @@ class ChatCompletionChunk(BaseObjectMixin): model: str """The model to generate the completion.""" - object: Literal['chat.completion.chunk'] + object: Literal["chat.completion.chunk"] """The object type, which is always `chat.completion.chunk`.""" system_fingerprint: Optional[str] = None """ - This fingerprint represents the backend configuration that the model runs with. 
- Can be used in conjunction with the `seed` request parameter to understand when + This fingerprint represents the backend configuration that the model runs with. # noqa: E501 + Can be used in conjunction with the `seed` request parameter to understand when # noqa: E501 backend changes have been made that might impact determinism. """ @@ -323,20 +367,27 @@ class ChatCompletionChunk(BaseObjectMixin): """ An optional field that will only be present when you set `stream_options: {"include_usage": true}` in your request. When present, it - contains a null value except for the last chunk which contains the token usage + contains a null value except for the last chunk which contains the token usage # noqa: E501 statistics for the entire request. """ + def __init__(self, **kwargs): - if 'usage' in kwargs and kwargs['usage'] is not None and kwargs[ - 'usage']: - self.usage = CompletionUsage(**kwargs.pop('usage', {})) + if ( + "usage" in kwargs + and kwargs["usage"] is not None + and kwargs["usage"] + ): + self.usage = CompletionUsage(**kwargs.pop("usage", {})) else: self.usage = None - if 'choices' in kwargs and kwargs['choices'] is not None and kwargs[ - 'choices']: + if ( + "choices" in kwargs + and kwargs["choices"] is not None + and kwargs["choices"] + ): choices = [] - for choice in kwargs.pop('choices', []): + for choice in kwargs.pop("choices", []): choices.append(Choice(**choice)) self.choices = choices else: diff --git a/dashscope/api_entities/dashscope_response.py b/dashscope/api_entities/dashscope_response.py index f1840c0..a49cd23 100644 --- a/dashscope/api_entities/dashscope_response.py +++ b/dashscope/api_entities/dashscope_response.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
import json @@ -39,7 +40,7 @@ def get(self, key, default=None): def setdefault(self, key, default=None): return super().setdefault(key, default) - def pop(self, key, default: Any): + def pop(self, key, default: Any): # type: ignore[override] return super().pop(key, default) def update(self, **kwargs): @@ -64,7 +65,7 @@ def __setattr__(self, attr, value): self[attr] = value def __repr__(self): - return '{0}({1})'.format(type(self).__name__, super().__repr__()) + return f"{type(self).__name__}({super().__repr__()})" def __str__(self): return json.dumps(self, ensure_ascii=False) @@ -83,6 +84,7 @@ class DashScopeAPIResponse(DictMixin): output (Any): The request output. usage (Any): The request usage information. """ + status_code: int request_id: str code: str @@ -90,29 +92,33 @@ class DashScopeAPIResponse(DictMixin): output: Any usage: Any - def __init__(self, - status_code: int, - request_id: str = '', - code: str = '', - message: str = '', - output: Any = None, - usage: Any = None, - **kwargs): - super().__init__(status_code=status_code, - request_id=request_id, - code=code, - message=message, - output=output, - usage=usage, - **kwargs) + def __init__( + self, + status_code: int, + request_id: str = "", + code: str = "", + message: str = "", + output: Any = None, + usage: Any = None, + **kwargs, + ): + super().__init__( + status_code=status_code, + request_id=request_id, + code=code, + message=message, + output=output, + usage=usage, + **kwargs, + ) class Role: - USER = 'user' - SYSTEM = 'system' - BOT = 'bot' - ASSISTANT = 'assistant' - ATTACHMENT = 'attachment' + USER = "user" + SYSTEM = "system" + BOT = "bot" + ASSISTANT = "assistant" + ATTACHMENT = "attachment" class Message(DictMixin): @@ -124,11 +130,11 @@ def __init__(self, role: str, content: Union[str, List] = None, **kwargs): @classmethod def from_generation_response(cls, response: DictMixin): - if 'text' in response.output and response.output['text'] is not None: - content = response.output['text'] + if 
"text" in response.output and response.output["text"] is not None: + content = response.output["text"] return Message(role=Role.ASSISTANT, content=content) else: - return response.output.choices[0]['message'] + return response.output.choices[0]["message"] @classmethod def from_conversation_response(cls, response: DictMixin): @@ -140,16 +146,20 @@ class Choice(DictMixin): finish_reason: str message: Message - def __init__(self, - finish_reason: str = None, - message: Message = None, - **kwargs): + def __init__( + self, + finish_reason: str = None, + message: Message = None, + **kwargs, + ): msg_object = None if message is not None and message: msg_object = Message(**message) - super().__init__(finish_reason=finish_reason, - message=msg_object, - **kwargs) + super().__init__( + finish_reason=finish_reason, + message=msg_object, + **kwargs, + ) @dataclass(init=False) @@ -159,17 +169,22 @@ class Audio(DictMixin): id: str expires_at: int - def __init__(self, - data: str = None, - url: str = None, - id: str = None, - expires_at: int = None, - **kwargs): - super().__init__(data=data, - url=url, - id=id, - expires_at=expires_at, - **kwargs) + def __init__( + self, + data: str = None, + url: str = None, + # pylint: disable=redefined-builtin + id: str = None, + expires_at: int = None, + **kwargs, + ): + super().__init__( + data=data, + url=url, + id=id, + expires_at=expires_at, + **kwargs, + ) @dataclass(init=False) @@ -178,20 +193,24 @@ class GenerationOutput(DictMixin): choices: List[Choice] finish_reason: str - def __init__(self, - text: str = None, - finish_reason: str = None, - choices: List[Choice] = None, - **kwargs): + def __init__( + self, + text: str = None, + finish_reason: str = None, + choices: List[Choice] = None, + **kwargs, + ): chs = None if choices is not None: chs = [] for choice in choices: chs.append(Choice(**choice)) - super().__init__(text=text, - finish_reason=finish_reason, - choices=chs, - **kwargs) + super().__init__( + text=text, + 
finish_reason=finish_reason, + choices=chs, + **kwargs, + ) @dataclass(init=False) @@ -199,13 +218,17 @@ class GenerationUsage(DictMixin): input_tokens: int output_tokens: int - def __init__(self, - input_tokens: int = 0, - output_tokens: int = 0, - **kwargs): - super().__init__(input_tokens=input_tokens, - output_tokens=output_tokens, - **kwargs) + def __init__( + self, + input_tokens: int = 0, + output_tokens: int = 0, + **kwargs, + ): + super().__init__( + input_tokens=input_tokens, + output_tokens=output_tokens, + **kwargs, + ) @dataclass(init=False) @@ -226,12 +249,15 @@ def from_api_response(api_response: DashScopeAPIResponse): code=api_response.code, message=api_response.message, output=GenerationOutput(**api_response.output), - usage=GenerationUsage(**usage)) + usage=GenerationUsage(**usage), + ) else: - return GenerationResponse(status_code=api_response.status_code, - request_id=api_response.request_id, - code=api_response.code, - message=api_response.message) + return GenerationResponse( + status_code=api_response.status_code, + request_id=api_response.request_id, + code=api_response.code, + message=api_response.message, + ) @dataclass(init=False) @@ -239,12 +265,14 @@ class MultiModalConversationOutput(DictMixin): choices: List[Choice] audio: Audio - def __init__(self, - text: str = None, - finish_reason: str = None, - choices: List[Choice] = None, - audio: Audio = None, - **kwargs): + def __init__( + self, + text: str = None, + finish_reason: str = None, + choices: List[Choice] = None, + audio: Audio = None, + **kwargs, + ): chs = None if choices is not None: chs = [] @@ -252,11 +280,13 @@ def __init__(self, chs.append(Choice(**choice)) if audio is not None: audio = Audio(**audio) - super().__init__(text=text, - finish_reason=finish_reason, - choices=chs, - audio=audio, - **kwargs) + super().__init__( + text=text, + finish_reason=finish_reason, + choices=chs, + audio=audio, + **kwargs, + ) @dataclass(init=False) @@ -267,15 +297,19 @@ class 
MultiModalConversationUsage(DictMixin): # TODO add image usage info. - def __init__(self, - input_tokens: int = 0, - output_tokens: int = 0, - characters: int = 0, - **kwargs): - super().__init__(input_tokens=input_tokens, - output_tokens=output_tokens, - characters=characters, - **kwargs) + def __init__( + self, + input_tokens: int = 0, + output_tokens: int = 0, + characters: int = 0, + **kwargs, + ): + super().__init__( + input_tokens=input_tokens, + output_tokens=output_tokens, + characters=characters, + **kwargs, + ) @dataclass(init=False) @@ -296,13 +330,15 @@ def from_api_response(api_response: DashScopeAPIResponse): code=api_response.code, message=api_response.message, output=MultiModalConversationOutput(**api_response.output), - usage=MultiModalConversationUsage(**usage)) + usage=MultiModalConversationUsage(**usage), + ) else: return MultiModalConversationResponse( status_code=api_response.status_code, request_id=api_response.request_id, code=api_response.code, - message=api_response.message) + message=api_response.message, + ) @dataclass(init=False) @@ -340,18 +376,22 @@ def from_api_response(api_response: DashScopeAPIResponse): if api_response.usage is not None: usage = TranscriptionUsage(**api_response.usage) - return TranscriptionResponse(status_code=api_response.status_code, - request_id=api_response.request_id, - code=api_response.code, - message=api_response.message, - output=output, - usage=usage) + return TranscriptionResponse( + status_code=api_response.status_code, + request_id=api_response.request_id, + code=api_response.code, + message=api_response.message, + output=output, + usage=usage, + ) else: - return TranscriptionResponse(status_code=api_response.status_code, - request_id=api_response.request_id, - code=api_response.code, - message=api_response.message) + return TranscriptionResponse( + status_code=api_response.status_code, + request_id=api_response.request_id, + code=api_response.code, + message=api_response.message, + ) 
@dataclass(init=False) @@ -381,32 +421,39 @@ def from_api_response(api_response: DashScopeAPIResponse): output = None usage = None if api_response.output is not None: - if 'sentence' in api_response.output: + if "sentence" in api_response.output: output = RecognitionOutput(**api_response.output) if api_response.usage is not None: usage = RecognitionUsage(**api_response.usage) - return RecognitionResponse(status_code=api_response.status_code, - request_id=api_response.request_id, - code=api_response.code, - message=api_response.message, - output=output, - usage=usage) + return RecognitionResponse( + status_code=api_response.status_code, + request_id=api_response.request_id, + code=api_response.code, + message=api_response.message, + output=output, + usage=usage, + ) else: - return RecognitionResponse(status_code=api_response.status_code, - request_id=api_response.request_id, - code=api_response.code, - message=api_response.message) + return RecognitionResponse( + status_code=api_response.status_code, + request_id=api_response.request_id, + code=api_response.code, + message=api_response.message, + ) @staticmethod def is_sentence_end(sentence: Dict[str, Any]) -> bool: - """Determine whether the speech recognition result is the end of a sentence. - This is a static method. + """Determine whether the speech recognition result is the end of a sentence. # noqa: E501 + This is a static method. 
""" result = False - if sentence is not None and 'end_time' in sentence and sentence[ - 'end_time'] is not None: + if ( + sentence is not None + and "end_time" in sentence + and sentence["end_time"] is not None + ): result = True return result @@ -448,21 +495,23 @@ def from_api_response(api_response: DashScopeAPIResponse): code=api_response.code, message=api_response.message, output=output, - usage=usage) + usage=usage, + ) else: return SpeechSynthesisResponse( status_code=api_response.status_code, request_id=api_response.request_id, code=api_response.code, - message=api_response.message) + message=api_response.message, + ) @dataclass(init=False) class ImageSynthesisResult(DictMixin): url: str - def __init__(self, url: str = '', **kwargs) -> None: + def __init__(self, url: str = "", **kwargs) -> None: super().__init__(url=url, **kwargs) @@ -471,21 +520,26 @@ class ImageSynthesisOutput(DictMixin): task_id: str task_status: str results: List[ImageSynthesisResult] - - def __init__(self, - task_id: str = None, - task_status: str = None, - results: List[ImageSynthesisResult] = [], - **kwargs): + # pylint: disable=dangerous-default-value + + def __init__( + self, + task_id: str = None, + task_status: str = None, + results: List[ImageSynthesisResult] = [], + **kwargs, + ): res = [] if len(results) > 0: for result in results: res.append(ImageSynthesisResult(**result)) - super().__init__(self, - task_id=task_id, - task_status=task_status, - results=res, - **kwargs) + super().__init__( + self, + task_id=task_id, + task_status=task_status, + results=res, + **kwargs, + ) @dataclass(init=False) @@ -494,16 +548,20 @@ class VideoSynthesisOutput(DictMixin): task_status: str video_url: str - def __init__(self, - task_id: str, - task_status: str, - video_url: str = '', - **kwargs): - super().__init__(self, - task_id=task_id, - task_status=task_status, - video_url=video_url, - **kwargs) + def __init__( + self, + task_id: str, + task_status: str, + video_url: str = "", + **kwargs, + 
): + super().__init__( + self, + task_id=task_id, + task_status=task_status, + video_url=video_url, + **kwargs, + ) @dataclass(init=False) @@ -520,15 +578,19 @@ class VideoSynthesisUsage(DictMixin): video_duration: int video_ratio: str - def __init__(self, - video_count: int = 1, - video_duration: int = 0, - video_ratio: str = '', - **kwargs): - super().__init__(video_count=video_count, - video_duration=video_duration, - video_ratio=video_ratio, - **kwargs) + def __init__( + self, + video_count: int = 1, + video_duration: int = 0, + video_ratio: str = "", + **kwargs, + ): + super().__init__( + video_count=video_count, + video_duration=video_duration, + video_ratio=video_ratio, + **kwargs, + ) @dataclass(init=False) @@ -546,18 +608,22 @@ def from_api_response(api_response: DashScopeAPIResponse): if api_response.usage is not None: usage = ImageSynthesisUsage(**api_response.usage) - return ImageSynthesisResponse(status_code=api_response.status_code, - request_id=api_response.request_id, - code=api_response.code, - message=api_response.message, - output=output, - usage=usage) + return ImageSynthesisResponse( + status_code=api_response.status_code, + request_id=api_response.request_id, + code=api_response.code, + message=api_response.message, + output=output, + usage=usage, + ) else: - return ImageSynthesisResponse(status_code=api_response.status_code, - request_id=api_response.request_id, - code=api_response.code, - message=api_response.message) + return ImageSynthesisResponse( + status_code=api_response.status_code, + request_id=api_response.request_id, + code=api_response.code, + message=api_response.message, + ) @dataclass(init=False) @@ -575,18 +641,22 @@ def from_api_response(api_response: DashScopeAPIResponse): if api_response.usage is not None: usage = VideoSynthesisUsage(**api_response.usage) - return VideoSynthesisResponse(status_code=api_response.status_code, - request_id=api_response.request_id, - code=api_response.code, - message=api_response.message, - 
output=output, - usage=usage) + return VideoSynthesisResponse( + status_code=api_response.status_code, + request_id=api_response.request_id, + code=api_response.code, + message=api_response.message, + output=output, + usage=usage, + ) else: - return VideoSynthesisResponse(status_code=api_response.status_code, - request_id=api_response.request_id, - code=api_response.code, - message=api_response.message) + return VideoSynthesisResponse( + status_code=api_response.status_code, + request_id=api_response.request_id, + code=api_response.code, + message=api_response.message, + ) @dataclass(init=False) @@ -595,15 +665,19 @@ class ReRankResult(DictMixin): relevance_score: float document: Dict = None - def __init__(self, - index: int, - relevance_score: float, - document: Dict = None, - **kwargs): - super().__init__(index=index, - relevance_score=relevance_score, - document=document, - **kwargs) + def __init__( + self, + index: int, + relevance_score: float, + document: Dict = None, + **kwargs, + ): + super().__init__( + index=index, + relevance_score=relevance_score, + document=document, + **kwargs, + ) @dataclass(init=False) @@ -639,17 +713,21 @@ def from_api_response(api_response: DashScopeAPIResponse): if api_response.usage: usage = api_response.usage - return ReRankResponse(status_code=api_response.status_code, - request_id=api_response.request_id, - code=api_response.code, - message=api_response.message, - output=ReRankOutput(**api_response.output), - usage=ReRankUsage(**usage)) + return ReRankResponse( + status_code=api_response.status_code, + request_id=api_response.request_id, + code=api_response.code, + message=api_response.message, + output=ReRankOutput(**api_response.output), + usage=ReRankUsage(**usage), + ) else: - return ReRankResponse(status_code=api_response.status_code, - request_id=api_response.request_id, - code=api_response.code, - message=api_response.message) + return ReRankResponse( + status_code=api_response.status_code, + 
request_id=api_response.request_id, + code=api_response.code, + message=api_response.message, + ) @dataclass(init=False) @@ -659,17 +737,22 @@ class TextToSpeechAudio(DictMixin): data: str url: str - def __init__(self, - expires_at: int, - id: str, - data: str = None, - url: str = None, - **kwargs): - super().__init__(expires_at=expires_at, - id=id, - data=data, - url=url, - **kwargs) + def __init__( + # pylint: disable=redefined-builtin + self, + expires_at: int, + id: str, + data: str = None, + url: str = None, + **kwargs, + ): + super().__init__( + expires_at=expires_at, + id=id, + data=data, + url=url, + **kwargs, + ) @dataclass(init=False) @@ -677,13 +760,17 @@ class TextToSpeechOutput(DictMixin): finish_reason: str audio: TextToSpeechAudio - def __init__(self, - finish_reason: str = None, - audio: TextToSpeechAudio = None, - **kwargs): - super().__init__(finish_reason=finish_reason, - audio=audio, - **kwargs) + def __init__( + self, + finish_reason: str = None, + audio: TextToSpeechAudio = None, + **kwargs, + ): + super().__init__( + finish_reason=finish_reason, + audio=audio, + **kwargs, + ) @dataclass(init=False) @@ -704,13 +791,15 @@ def from_api_response(api_response: DashScopeAPIResponse): code=api_response.code, message=api_response.message, output=TextToSpeechOutput(**api_response.output), - usage=MultiModalConversationUsage(**usage)) + usage=MultiModalConversationUsage(**usage), + ) else: return TextToSpeechResponse( status_code=api_response.status_code, request_id=api_response.request_id, code=api_response.code, - message=api_response.message) + message=api_response.message, + ) @dataclass(init=False) @@ -718,12 +807,14 @@ class ImageGenerationOutput(DictMixin): choices: List[Choice] audio: Audio - def __init__(self, - text: str = None, - finish_reason: str = None, - choices: List[Choice] = None, - audio: Audio = None, - **kwargs): + def __init__( + self, + text: str = None, + finish_reason: str = None, + choices: List[Choice] = None, + audio: Audio 
= None, + **kwargs, + ): chs = None if choices is not None: chs = [] @@ -731,11 +822,13 @@ def __init__(self, chs.append(Choice(**choice)) if audio is not None: audio = Audio(**audio) - super().__init__(text=text, - finish_reason=finish_reason, - choices=chs, - audio=audio, - **kwargs) + super().__init__( + text=text, + finish_reason=finish_reason, + choices=chs, + audio=audio, + **kwargs, + ) @dataclass(init=False) @@ -746,15 +839,20 @@ class ImageGenerationUsage(DictMixin): # TODO add image usage info. - def __init__(self, - input_tokens: int = 0, - output_tokens: int = 0, - characters: int = 0, - **kwargs): - super().__init__(input_tokens=input_tokens, - output_tokens=output_tokens, - characters=characters, - **kwargs) + def __init__( + self, + input_tokens: int = 0, + output_tokens: int = 0, + characters: int = 0, + **kwargs, + ): + super().__init__( + input_tokens=input_tokens, + output_tokens=output_tokens, + characters=characters, + **kwargs, + ) + @dataclass(init=False) class ImageGenerationResponse(DashScopeAPIResponse): @@ -774,10 +872,12 @@ def from_api_response(api_response: DashScopeAPIResponse): code=api_response.code, message=api_response.message, output=ImageGenerationOutput(**api_response.output), - usage=ImageGenerationUsage(**usage)) + usage=ImageGenerationUsage(**usage), + ) else: return ImageGenerationResponse( status_code=api_response.status_code, request_id=api_response.request_id, code=api_response.code, - message=api_response.message) + message=api_response.message, + ) diff --git a/dashscope/api_entities/encryption.py b/dashscope/api_entities/encryption.py index d51a53e..9f8d8ba 100644 --- a/dashscope/api_entities/encryption.py +++ b/dashscope/api_entities/encryption.py @@ -1,29 +1,31 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
import base64 import json -from dataclasses import dataclass import os -from typing import Optional import requests from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes -from cryptography.hazmat.primitives import serialization, hashes +from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import padding from cryptography.hazmat.backends import default_backend import dashscope -from dashscope.common.constants import ENCRYPTION_AES_SECRET_KEY_BYTES, ENCRYPTION_AES_IV_LENGTH +from dashscope.common.constants import ( + ENCRYPTION_AES_SECRET_KEY_BYTES, + ENCRYPTION_AES_IV_LENGTH, +) from dashscope.common.logging import logger class Encryption: def __init__(self): - self.pub_key_id: str = '' - self.pub_key_str: str = '' - self.aes_key_bytes: bytes = b'' - self.encrypted_aes_key_str: str = '' - self.iv_bytes: bytes = b'' - self.base64_iv_str: str = '' + self.pub_key_id: str = "" + self.pub_key_str: str = "" + self.aes_key_bytes: bytes = b"" + self.encrypted_aes_key_str: str = "" + self.iv_bytes: bytes = b"" + self.base64_iv_str: str = "" self.valid: bool = False def initialize(self): @@ -31,8 +33,8 @@ def initialize(self): if not public_keys: return - public_key_str = public_keys.get('public_key') - public_key_id = public_keys.get('public_key_id') + public_key_str = public_keys.get("public_key") + public_key_id = public_keys.get("public_key_id") if not public_key_str or not public_key_id: logger.error("public keys data not valid") return @@ -40,8 +42,11 @@ def initialize(self): aes_key_bytes = self._generate_aes_secret_key() iv_bytes = self._generate_iv() - encrypted_aes_key_str = self._encrypt_aes_key_with_rsa(aes_key_bytes, public_key_str) - base64_iv_str = base64.b64encode(iv_bytes).decode('utf-8') + encrypted_aes_key_str = self._encrypt_aes_key_with_rsa( + aes_key_bytes, + public_key_str, + ) + base64_iv_str = base64.b64encode(iv_bytes).decode("utf-8") self.pub_key_id = public_key_id 
self.pub_key_str = public_key_str @@ -53,11 +58,18 @@ def initialize(self): self.valid = True def encrypt(self, dict_plaintext): - return self._encrypt_text_with_aes(json.dumps(dict_plaintext, ensure_ascii=False), - self.aes_key_bytes, self.iv_bytes) + return self._encrypt_text_with_aes( + json.dumps(dict_plaintext, ensure_ascii=False), + self.aes_key_bytes, + self.iv_bytes, + ) def decrypt(self, base64_ciphertext): - return self._decrypt_text_with_aes(base64_ciphertext, self.aes_key_bytes, self.iv_bytes) + return self._decrypt_text_with_aes( + base64_ciphertext, + self.aes_key_bytes, + self.iv_bytes, + ) def is_valid(self): return self.valid @@ -73,18 +85,18 @@ def get_base64_iv_str(self): @staticmethod def _get_public_keys(): - url = dashscope.base_http_api_url + '/public-keys/latest' + url = dashscope.base_http_api_url + "/public-keys/latest" headers = { - "Authorization": f"Bearer {dashscope.api_key}" + "Authorization": f"Bearer {dashscope.api_key}", } response = requests.get(url, headers=headers) if response.status_code != 200: - logger.error("exceptional public key response: %s" % response) + logger.error("exceptional public key response: %s", response) return None json_resp = response.json() - response_data = json_resp.get('data') + response_data = json_resp.get("data") if not response_data: logger.error("no valid data in public key response") @@ -108,14 +120,16 @@ def _encrypt_text_with_aes(plaintext, key, iv): aes_gcm = Cipher( algorithms.AES(key), modes.GCM(iv, tag=None), - backend=default_backend() + backend=default_backend(), ).encryptor() # 关联数据设为空(根据需求可调整) - aes_gcm.authenticate_additional_data(b'') + aes_gcm.authenticate_additional_data(b"") # 加密数据 - ciphertext = aes_gcm.update(plaintext.encode('utf-8')) + aes_gcm.finalize() + ciphertext = ( + aes_gcm.update(plaintext.encode("utf-8")) + aes_gcm.finalize() + ) # 获取认证标签 tag = aes_gcm.tag @@ -124,7 +138,7 @@ def _encrypt_text_with_aes(plaintext, key, iv): encrypted_data = ciphertext + tag # 返回Base64编码结果 
- return base64.b64encode(encrypted_data).decode('utf-8') + return base64.b64encode(encrypted_data).decode("utf-8") @staticmethod def _decrypt_text_with_aes(base64_ciphertext, aes_key, iv): @@ -141,17 +155,17 @@ def _decrypt_text_with_aes(base64_ciphertext, aes_key, iv): aes_gcm = Cipher( algorithms.AES(aes_key), modes.GCM(iv, tag), - backend=default_backend() + backend=default_backend(), ).decryptor() # 验证关联数据(与加密时一致) - aes_gcm.authenticate_additional_data(b'') + aes_gcm.authenticate_additional_data(b"") # 解密数据 decrypted_bytes = aes_gcm.update(ciphertext) + aes_gcm.finalize() # 明文 - plaintext = decrypted_bytes.decode('utf-8') + plaintext = decrypted_bytes.decode("utf-8") return json.loads(plaintext) @@ -165,15 +179,15 @@ def _encrypt_aes_key_with_rsa(aes_key, public_key_str): # 加载公钥 public_key = serialization.load_der_public_key( public_key_bytes, - backend=default_backend() + backend=default_backend(), ) - base64_aes_key = base64.b64encode(aes_key).decode('utf-8') + base64_aes_key = base64.b64encode(aes_key).decode("utf-8") # 使用RSA加密 encrypted_bytes = public_key.encrypt( - base64_aes_key.encode('utf-8'), - padding.PKCS1v15() + base64_aes_key.encode("utf-8"), + padding.PKCS1v15(), ) - return base64.b64encode(encrypted_bytes).decode('utf-8') + return base64.b64encode(encrypted_bytes).decode("utf-8") diff --git a/dashscope/api_entities/http_request.py b/dashscope/api_entities/http_request.py index 386a03e..d84dbc3 100644 --- a/dashscope/api_entities/http_request.py +++ b/dashscope/api_entities/http_request.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
import datetime import json @@ -11,30 +12,37 @@ from dashscope.api_entities.base_request import AioBaseRequest from dashscope.api_entities.dashscope_response import DashScopeAPIResponse -from dashscope.common.constants import (DEFAULT_REQUEST_TIMEOUT_SECONDS, - SSE_CONTENT_TYPE, HTTPMethod) +from dashscope.common.constants import ( + DEFAULT_REQUEST_TIMEOUT_SECONDS, + SSE_CONTENT_TYPE, + HTTPMethod, +) from dashscope.common.error import UnsupportedHTTPMethod from dashscope.common.logging import logger -from dashscope.common.utils import (_handle_aio_stream, - _handle_aiohttp_failed_response, - _handle_http_failed_response, - _handle_stream) +from dashscope.common.utils import ( + _handle_aio_stream, + _handle_aiohttp_failed_response, + _handle_http_failed_response, + _handle_stream, +) from dashscope.api_entities.encryption import Encryption class HttpRequest(AioBaseRequest): - def __init__(self, - url: str, - api_key: str, - http_method: str, - stream: bool = True, - async_request: bool = False, - query: bool = False, - timeout: int = DEFAULT_REQUEST_TIMEOUT_SECONDS, - task_id: str = None, - flattened_output: bool = False, - encryption: Optional[Encryption] = None, - user_agent: str = '') -> None: + def __init__( + self, + url: str, + api_key: str, + http_method: str, + stream: bool = True, + async_request: bool = False, + query: bool = False, + timeout: int = DEFAULT_REQUEST_TIMEOUT_SECONDS, + task_id: str = None, + flattened_output: bool = False, + encryption: Optional[Encryption] = None, + user_agent: str = "", + ) -> None: """HttpSSERequest, processing http server sent event stream. 
Args: @@ -54,43 +62,45 @@ def __init__(self, self.async_request = async_request self.encryption = encryption self.headers = { - 'Accept': 'application/json', - 'Authorization': 'Bearer %s' % api_key, + "Accept": "application/json", + "Authorization": f"Bearer {api_key}", **self.headers, } if encryption and encryption.is_valid(): self.headers = { - "X-DashScope-EncryptionKey": json.dumps({ - "public_key_id": encryption.get_pub_key_id(), - "encrypt_key": encryption.get_encrypted_aes_key_str(), - "iv": encryption.get_base64_iv_str() - }), + "X-DashScope-EncryptionKey": json.dumps( + { + "public_key_id": encryption.get_pub_key_id(), + "encrypt_key": encryption.get_encrypted_aes_key_str(), + "iv": encryption.get_base64_iv_str(), + }, + ), **self.headers, } self.query = query if self.async_request and self.query is False: self.headers = { - 'X-DashScope-Async': 'enable', + "X-DashScope-Async": "enable", **self.headers, } self.method = http_method if self.method == HTTPMethod.POST: - self.headers['Content-Type'] = 'application/json' + self.headers["Content-Type"] = "application/json" self.stream = stream if self.stream: - self.headers['Accept'] = SSE_CONTENT_TYPE - self.headers['X-Accel-Buffering'] = 'no' - self.headers['X-DashScope-SSE'] = 'enable' + self.headers["Accept"] = SSE_CONTENT_TYPE + self.headers["X-Accel-Buffering"] = "no" + self.headers["X-DashScope-SSE"] = "enable" if self.query: - self.url = self.url.replace('api', 'api-task') - self.url += '%s' % task_id + self.url = self.url.replace("api", "api-task") + self.url += f"{task_id}" if timeout is None: self.timeout = DEFAULT_REQUEST_TIMEOUT_SECONDS else: - self.timeout = timeout + self.timeout = timeout # type: ignore[has-type] def add_header(self, key, value): self.headers[key] = value @@ -126,40 +136,50 @@ async def _handle_aio_request(self): try: connector = aiohttp.TCPConnector( ssl=ssl.create_default_context( - cafile=certifi.where())) + cafile=certifi.where(), + ), + ) async with aiohttp.ClientSession( - 
connector=connector, - timeout=aiohttp.ClientTimeout(total=self.timeout), - headers=self.headers) as session: - logger.debug('Starting request: %s' % self.url) + connector=connector, + timeout=aiohttp.ClientTimeout(total=self.timeout), + headers=self.headers, + ) as session: + logger.debug("Starting request: %s", self.url) if self.method == HTTPMethod.POST: is_form, obj = False, {} - if hasattr(self, 'data') and self.data is not None: + if hasattr(self, "data") and self.data is not None: is_form, obj = self.data.get_aiohttp_payload() if is_form: headers = {**self.headers, **obj.headers} - response = await session.post(url=self.url, - data=obj, - headers=headers) + response = await session.post( + url=self.url, + data=obj, + headers=headers, + ) else: - response = await session.request('POST', - url=self.url, - json=obj, - headers=self.headers) + response = await session.request( + "POST", + url=self.url, + json=obj, + headers=self.headers, + ) elif self.method == HTTPMethod.GET: # 添加条件判断 params = {} - if hasattr(self, 'data') and self.data is not None: - params = getattr(self.data, 'parameters', {}) + if hasattr(self, "data") and self.data is not None: + params = getattr(self.data, "parameters", {}) if params: params = self.__handle_parameters(params) - response = await session.get(url=self.url, - params=params, - headers=self.headers) + response = await session.get( + url=self.url, + params=params, + headers=self.headers, + ) else: - raise UnsupportedHTTPMethod('Unsupported http method: %s' % - self.method) - logger.debug('Response returned: %s' % self.url) + raise UnsupportedHTTPMethod( + f"Unsupported http method: {self.method}", + ) + logger.debug("Response returned: %s", self.url) async with response: async for rsp in self._handle_aio_response(response): yield rsp @@ -172,175 +192,217 @@ async def _handle_aio_request(self): @staticmethod def __handle_parameters(params: dict) -> dict: + # pylint: disable=too-many-return-statements def __format(value): if 
isinstance(value, bool): return str(value).lower() elif isinstance(value, (str, int, float)): return value elif value is None: - return '' + return "" elif isinstance(value, (datetime.datetime, datetime.date)): return value.isoformat() elif isinstance(value, (list, tuple)): - return ','.join(str(__format(x)) for x in value) + return ",".join(str(__format(x)) for x in value) elif isinstance(value, dict): return json.dumps(value) else: try: return str(value) except Exception as e: - raise ValueError(f"Unsupported type {type(value)} for param formatting: {e}") + # pylint: disable=raise-missing-from + raise ValueError( + f"Unsupported type {type(value)} for param formatting: {e}", # noqa: E501 + ) formatted = {} for k, v in params.items(): formatted[k] = __format(v) + # pylint: disable=too-many-statements return formatted - async def _handle_aio_response(self, response: aiohttp.ClientResponse): - request_id = '' - if (response.status == HTTPStatus.OK and self.stream - and SSE_CONTENT_TYPE in response.content_type): + async def _handle_aio_response( # pylint: disable=too-many-branches, too-many-statements # noqa: E501 + self, + response: aiohttp.ClientResponse, + ): + request_id = "" + if ( + response.status == HTTPStatus.OK + and self.stream + and SSE_CONTENT_TYPE in response.content_type + ): async for is_error, status_code, data in _handle_aio_stream( - response): + response, + ): try: output = None usage = None msg = json.loads(data) if not is_error: - if 'output' in msg: - output = msg['output'] - if 'usage' in msg: - usage = msg['usage'] - if 'request_id' in msg: - request_id = msg['request_id'] + if "output" in msg: + output = msg["output"] + if "usage" in msg: + usage = msg["usage"] + if "request_id" in msg: + request_id = msg["request_id"] except json.JSONDecodeError: yield DashScopeAPIResponse( request_id=request_id, status_code=HTTPStatus.INTERNAL_SERVER_ERROR, - code='Unknown', - message=data) + code="Unknown", + message=data, + ) continue if is_error: - 
yield DashScopeAPIResponse(request_id=request_id, - status_code=status_code, - code=msg['code'], - message=msg['message']) + yield DashScopeAPIResponse( + request_id=request_id, + status_code=status_code, + code=msg["code"], + message=msg["message"], + ) else: if self.encryption and self.encryption.is_valid(): output = self.encryption.decrypt(output) - yield DashScopeAPIResponse(request_id=request_id, - status_code=HTTPStatus.OK, - output=output, - usage=usage) - elif (response.status == HTTPStatus.OK - and 'multipart' in response.content_type): + yield DashScopeAPIResponse( + request_id=request_id, + status_code=HTTPStatus.OK, + output=output, + usage=usage, + ) + elif ( + response.status == HTTPStatus.OK + and "multipart" in response.content_type + ): reader = aiohttp.MultipartReader.from_response(response) output = {} while True: part = await reader.next() if part is None: + # pylint: disable=consider-using-get break output[part.name] = await part.read() - if 'request_id' in output: - request_id = output['request_id'] + if "request_id" in output: # pylint: disable=consider-using-get + request_id = output["request_id"] if self.encryption and self.encryption.is_valid(): output = self.encryption.decrypt(output) - yield DashScopeAPIResponse(request_id=request_id, - status_code=HTTPStatus.OK, - output=output) + yield DashScopeAPIResponse( + request_id=request_id, + status_code=HTTPStatus.OK, + output=output, + ) elif response.status == HTTPStatus.OK: json_content = await response.json() output = None usage = None - if 'output' in json_content and json_content['output'] is not None: - output = json_content['output'] + if "output" in json_content and json_content["output"] is not None: + output = json_content["output"] # Compatible with wan - elif 'data' in json_content and json_content['data'] is not None\ - and isinstance(json_content['data'], list)\ - and len(json_content['data']) > 0\ - and 'task_id' in json_content['data'][0]: + elif ( + "data" in json_content + 
and json_content["data"] is not None + and isinstance(json_content["data"], list) + and len(json_content["data"]) > 0 + and "task_id" in json_content["data"][0] + ): output = json_content - if 'usage' in json_content: - usage = json_content['usage'] - if 'request_id' in json_content: - request_id = json_content['request_id'] + if "usage" in json_content: + usage = json_content["usage"] + if "request_id" in json_content: + request_id = json_content["request_id"] if self.encryption and self.encryption.is_valid(): output = self.encryption.decrypt(output) - yield DashScopeAPIResponse(request_id=request_id, - status_code=HTTPStatus.OK, - output=output, - usage=usage) + yield DashScopeAPIResponse( + request_id=request_id, + status_code=HTTPStatus.OK, + output=output, + usage=usage, + ) else: yield await _handle_aiohttp_failed_response(response) - def _handle_response(self, response: requests.Response): - request_id = '' - if (response.status_code == HTTPStatus.OK and self.stream - and SSE_CONTENT_TYPE in response.headers.get( - 'content-type', '')): + def _handle_response( # pylint: disable=too-many-branches + self, + response: requests.Response, + ): + request_id = "" + if ( + response.status_code == HTTPStatus.OK + and self.stream + and SSE_CONTENT_TYPE + in response.headers.get( + "content-type", + "", + ) + ): for is_error, status_code, event in _handle_stream(response): try: data = event.data output = None usage = None msg = json.loads(data) - logger.debug('Stream message: %s' % msg) + logger.debug("Stream message: %s", msg) if not is_error: - if 'output' in msg: - output = msg['output'] - if 'usage' in msg: - usage = msg['usage'] - if 'request_id' in msg: - request_id = msg['request_id'] + if "output" in msg: + output = msg["output"] + if "usage" in msg: + usage = msg["usage"] + if "request_id" in msg: + request_id = msg["request_id"] except json.JSONDecodeError: yield DashScopeAPIResponse( request_id=request_id, status_code=HTTPStatus.BAD_REQUEST, output=None, - 
code='Unknown', - message=data) + code="Unknown", + message=data, + ) continue if is_error: yield DashScopeAPIResponse( request_id=request_id, status_code=status_code, output=None, - code=msg['code'] - if 'code' in msg else None, # noqa E501 - message=msg['message'] - if 'message' in msg else None) # noqa E501 + code=msg["code"] + if "code" in msg + else None, # noqa E501 + message=msg["message"] if "message" in msg else None, + ) # noqa E501 else: if self.flattened_output: yield msg else: if self.encryption and self.encryption.is_valid(): output = self.encryption.decrypt(output) - yield DashScopeAPIResponse(request_id=request_id, - status_code=HTTPStatus.OK, - output=output, - usage=usage) + yield DashScopeAPIResponse( + request_id=request_id, + status_code=HTTPStatus.OK, + output=output, + usage=usage, + ) elif response.status_code == HTTPStatus.OK: json_content = response.json() - logger.debug('Response: %s' % json_content) + logger.debug("Response: %s", json_content) output = None usage = None - if 'task_id' in json_content: - output = {'task_id': json_content['task_id']} - if 'output' in json_content: - output = json_content['output'] - if 'usage' in json_content: - usage = json_content['usage'] - if 'request_id' in json_content: - request_id = json_content['request_id'] + if "task_id" in json_content: + output = {"task_id": json_content["task_id"]} + if "output" in json_content: + output = json_content["output"] + if "usage" in json_content: + usage = json_content["usage"] + if "request_id" in json_content: + request_id = json_content["request_id"] if self.flattened_output: yield json_content else: if self.encryption and self.encryption.is_valid(): output = self.encryption.decrypt(output) - yield DashScopeAPIResponse(request_id=request_id, - status_code=HTTPStatus.OK, - output=output, - usage=usage) + yield DashScopeAPIResponse( + request_id=request_id, + status_code=HTTPStatus.OK, + output=output, + usage=usage, + ) else: yield 
_handle_http_failed_response(response) @@ -351,27 +413,34 @@ def _handle_request(self): is_form, form, obj = self.data.get_http_payload() if is_form: headers = {**self.headers} - headers.pop('Content-Type') - response = session.post(url=self.url, - data=obj, - files=form, - headers=headers, - timeout=self.timeout) + headers.pop("Content-Type") + response = session.post( + url=self.url, + data=obj, + files=form, + headers=headers, + timeout=self.timeout, + ) else: - logger.debug('Request body: %s' % obj) - response = session.post(url=self.url, - stream=self.stream, - json=obj, - headers={**self.headers}, - timeout=self.timeout) + logger.debug("Request body: %s", obj) + response = session.post( + url=self.url, + stream=self.stream, + json=obj, + headers={**self.headers}, + timeout=self.timeout, + ) elif self.method == HTTPMethod.GET: - response = session.get(url=self.url, - params=self.data.parameters, - headers=self.headers, - timeout=self.timeout) + response = session.get( + url=self.url, + params=self.data.parameters, + headers=self.headers, + timeout=self.timeout, + ) else: - raise UnsupportedHTTPMethod('Unsupported http method: %s' % - self.method) + raise UnsupportedHTTPMethod( + f"Unsupported http method: {self.method}", + ) for rsp in self._handle_response(response): yield rsp except BaseException as e: diff --git a/dashscope/api_entities/websocket_request.py b/dashscope/api_entities/websocket_request.py index 2d99787..0473709 100644 --- a/dashscope/api_entities/websocket_request.py +++ b/dashscope/api_entities/websocket_request.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
import asyncio @@ -10,17 +11,29 @@ from dashscope.api_entities.base_request import AioBaseRequest from dashscope.api_entities.dashscope_response import DashScopeAPIResponse -from dashscope.common.constants import (DEFAULT_REQUEST_TIMEOUT_SECONDS, - SERVICE_503_MESSAGE, - WEBSOCKET_ERROR_CODE) -from dashscope.common.error import (RequestFailure, UnexpectedMessageReceived, - UnknownMessageReceived) +from dashscope.common.constants import ( + DEFAULT_REQUEST_TIMEOUT_SECONDS, + SERVICE_503_MESSAGE, + WEBSOCKET_ERROR_CODE, +) +from dashscope.common.error import ( + RequestFailure, + UnexpectedMessageReceived, + UnknownMessageReceived, +) from dashscope.common.logging import logger from dashscope.common.utils import async_to_sync -from dashscope.protocol.websocket import (ACTION_KEY, ERROR_MESSAGE, - ERROR_NAME, EVENT_KEY, HEADER, - TASK_ID, ActionType, EventType, - WebsocketStreamingMode) +from dashscope.protocol.websocket import ( + ACTION_KEY, + ERROR_MESSAGE, + ERROR_NAME, + EVENT_KEY, + HEADER, + TASK_ID, + ActionType, + EventType, + WebsocketStreamingMode, +) class WebSocketRequest(AioBaseRequest): @@ -34,9 +47,10 @@ def __init__( timeout: int = DEFAULT_REQUEST_TIMEOUT_SECONDS, flattened_output: bool = False, pre_task_id=None, - user_agent: str = '', + user_agent: str = "", ) -> None: super().__init__(user_agent=user_agent) + # pylint: disable=pointless-string-statement """HttpRequest. 
Args: @@ -53,17 +67,17 @@ def __init__( if timeout is None: self.timeout = DEFAULT_REQUEST_TIMEOUT_SECONDS else: - self.timeout = timeout + self.timeout = timeout # type: ignore[has-type] self.ws_stream_mode = ws_stream_mode self.is_binary_input = is_binary_input self.headers = { - 'Authorization': 'bearer %s' % api_key, + "Authorization": f"bearer {api_key}", **self.headers, } self.task_headers = { - 'streaming': self.ws_stream_mode, + "streaming": self.ws_stream_mode, } self.pre_task_id = pre_task_id @@ -98,119 +112,158 @@ async def aio_call(self): pass return result - async def connection_handler(self): + async def connection_handler(self): # pylint: disable=too-many-branches try: task_id = None - async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout( - total=self.timeout)) as session: - async with session.ws_connect(self.url, - headers=self.headers, - heartbeat=6000) as ws: + async with aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout( + total=self.timeout, + ), + ) as session: + async with session.ws_connect( + self.url, + headers=self.headers, + heartbeat=6000, + ) as ws: await self._start_task(ws) # send start task action. - task_id = self.task_headers['task_id'] + task_id = self.task_headers["task_id"] await self._wait_for_task_started( - ws) # wait for task started event. # noqa E501 + ws, + ) # wait for task started event. # noqa E501 if self.ws_stream_mode == WebsocketStreamingMode.NONE: if self.is_binary_input: # send the binary package data = self.data.get_batch_binary_data() await ws.send_bytes(list(data.values())[0]) - is_binary, result = await self._receive_batch_data_task( # noqa E501 - ws) + ( + is_binary, + result, + ) = await self._receive_batch_data_task( # noqa E501 + ws, + ) # do not need send finished task message. yield self._to_DashScopeAPIResponse( - task_id, is_binary, result) + task_id, + is_binary, + result, + ) elif self.ws_stream_mode == WebsocketStreamingMode.IN: # server is in, we send streaming out. 
await self._send_continue_task_data(ws) - is_binary, result = await self._receive_batch_data_task( # noqa E501 - ws) + ( + is_binary, + result, + ) = await self._receive_batch_data_task( # noqa E501 + ws, + ) # do not need send finished task message. yield self._to_DashScopeAPIResponse( - task_id, is_binary, result) + task_id, + is_binary, + result, + ) elif self.ws_stream_mode == WebsocketStreamingMode.OUT: # we send batch data, server streaming output data. if self.is_binary_input: # send only binary package. data = self.data.get_batch_binary_data() await ws.send_bytes(list(data.values())[0]) - async for is_binary, message in self._receive_streaming_data_task( # noqa E501 - ws): + async for is_binary, message in self._receive_streaming_data_task( # noqa E501 # pylint: disable=line-too-long + ws, + ): yield self._to_DashScopeAPIResponse( - task_id, is_binary, message) + task_id, + is_binary, + message, + ) else: # duplex mode asyncio.create_task(self._send_continue_task_data(ws)) - async for is_binary, message in self._receive_streaming_data_task( # noqa E501 - ws): + async for is_binary, message in self._receive_streaming_data_task( # noqa E501 # pylint: disable=line-too-long + ws, + ): yield self._to_DashScopeAPIResponse( - task_id, is_binary, message) + task_id, + is_binary, + message, + ) except RequestFailure as e: - yield DashScopeAPIResponse(request_id=e.request_id, - status_code=e.http_code, - output=None, - code=e.name, - message=e.message) + yield DashScopeAPIResponse( + request_id=e.request_id, + status_code=e.http_code, + output=None, + code=e.name, + message=e.message, + ) except aiohttp.ClientConnectorError as e: logger.exception(e) - yield DashScopeAPIResponse(request_id='', - status_code=-1, - code='ClientConnectorError', - message=str(e)) + yield DashScopeAPIResponse( + request_id="", + status_code=-1, + code="ClientConnectorError", + message=str(e), + ) except aiohttp.WSServerHandshakeError as e: code = e.status msg = e.message if e.status in 
[HTTPStatus.FORBIDDEN, HTTPStatus.UNAUTHORIZED]: - msg = 'Unauthorized, your api-key is invalid!' + msg = "Unauthorized, your api-key is invalid!" elif e.status == HTTPStatus.SERVICE_UNAVAILABLE: msg = SERVICE_503_MESSAGE else: pass - yield DashScopeAPIResponse(request_id=task_id, - status_code=code, - code=code, - message=msg) + yield DashScopeAPIResponse( + request_id=task_id, + status_code=code, + code=code, + message=msg, + ) except BaseException as e: logger.exception(e) - yield DashScopeAPIResponse(request_id='', - status_code=-1, - code='Unknown', - message='Error type: %s, message: %s' % - (type(e), e)) + yield DashScopeAPIResponse( + request_id="", + status_code=-1, + code="Unknown", + message=f"Error type: {type(e)}, message: {e}", + ) def _to_DashScopeAPIResponse(self, task_id, is_binary, result): if is_binary: - return DashScopeAPIResponse(request_id=task_id, - status_code=HTTPStatus.OK, - output=result) + return DashScopeAPIResponse( + request_id=task_id, + status_code=HTTPStatus.OK, + output=result, + ) else: # get output and usage. output = {} usage = {} - if 'output' in result: - output = result['output'] - if 'usage' in result: - usage = result['usage'] - return DashScopeAPIResponse(request_id=task_id, - status_code=HTTPStatus.OK, - output=output, - usage=usage) + if "output" in result: + output = result["output"] + if "usage" in result: + usage = result["usage"] + return DashScopeAPIResponse( + request_id=task_id, + status_code=HTTPStatus.OK, + output=output, + usage=usage, + ) async def _receive_streaming_data_task(self, ws): # check if request stream data, re return an iterator, # otherwise we collect data and return user. 
# no matter what, the response is streaming is_binary_output = False - while True: + while True: # pylint: disable=R1702 msg = await ws.receive() await self._check_websocket_unexpected_message(msg) if msg.type == aiohttp.WSMsgType.TEXT: msg_json = msg.json() - logger.debug('Receive %s event' % msg_json[HEADER][EVENT_KEY]) + logger.debug("Receive %s event", msg_json[HEADER][EVENT_KEY]) if msg_json[HEADER][EVENT_KEY] == EventType.GENERATED: - payload = msg_json['payload'] + payload = msg_json["payload"] yield False, payload elif msg_json[HEADER][EVENT_KEY] == EventType.FINISHED: payload = None - if 'payload' in msg_json: - payload = msg_json['payload'] + if "payload" in msg_json: + payload = msg_json["payload"] logger.debug(payload) if payload: yield False, payload @@ -224,7 +277,7 @@ async def _receive_streaming_data_task(self, ws): elif msg_json[HEADER][EVENT_KEY] == EventType.FAILED: self._on_failed(msg_json) else: - error = 'Receive unknown message: %s' % msg_json + error = f"Receive unknown message: {msg_json}" logger.error(error) raise UnknownMessageReceived(error) elif msg.type == aiohttp.WSMsgType.BINARY: @@ -232,63 +285,71 @@ async def _receive_streaming_data_task(self, ws): yield True, msg.data def _on_failed(self, details): - error = RequestFailure(request_id=details[HEADER][TASK_ID], - http_code=WEBSOCKET_ERROR_CODE, - name=details[HEADER][ERROR_NAME], - message=details[HEADER][ERROR_MESSAGE]) + error = RequestFailure( + request_id=details[HEADER][TASK_ID], + http_code=WEBSOCKET_ERROR_CODE, + name=details[HEADER][ERROR_NAME], + message=details[HEADER][ERROR_MESSAGE], + ) logger.error(error) raise error async def _start_task(self, ws): if self.pre_task_id is not None: - self.task_headers['task_id'] = self.pre_task_id + self.task_headers["task_id"] = self.pre_task_id else: - self.task_headers['task_id'] = uuid.uuid4().hex # create task id. + self.task_headers["task_id"] = uuid.uuid4().hex # create task id. 
task_header = {**self.task_headers, ACTION_KEY: ActionType.START} # for binary data, the start action has no input, only parameters. start_data = self.data.get_websocket_start_data() message = self._build_up_message(task_header, start_data) - logger.debug('Send start task: {}'.format(message)) + logger.debug("Send start task: %s", message) await ws.send_str(message) async def _send_finished_task(self, ws): task_header = {**self.task_headers, ACTION_KEY: ActionType.FINISHED} - payload = {'input': {}} + payload = {"input": {}} message = self._build_up_message(task_header, payload) - logger.debug('Send finish task: {}'.format(message)) + logger.debug("Send finish task: %s", message) await ws.send_str(message) async def _send_continue_task_data(self, ws): headers = { - 'task_id': self.task_headers['task_id'], - 'action': 'continue-task' + "task_id": self.task_headers["task_id"], + "action": "continue-task", } - for input in self.data.get_websocket_continue_data(): + for input_item in self.data.get_websocket_continue_data(): if self.is_binary_input: - if len(input) > 0: - if isinstance(input, bytes): - await ws.send_bytes(input) + if len(input_item) > 0: + if isinstance(input_item, bytes): + await ws.send_bytes(input_item) logger.debug( - 'Send continue task with bytes: {}'.format( - len(input))) + "Send continue task with bytes: %s", + len(input_item), + ) else: - await ws.send_bytes(list(input.values())[0]) + await ws.send_bytes(list(input_item.values())[0]) logger.debug( - 'Send continue task with list[byte]: {}'.format( - len(input))) + "Send continue task with list[byte]: %s", + len(input_item), + ) else: - if len(input) > 0: - message = self._build_up_message(headers=headers, - payload=input) - logger.debug('Send continue task: {}'.format(message)) + if len(input_item) > 0: + message = self._build_up_message( + headers=headers, + payload=input_item, + ) + logger.debug("Send continue task: %s", message) await ws.send_str(message) await asyncio.sleep(0.000001) # 
data send completed, and send task completed. await self._send_finished_task(ws) - async def _receive_batch_data_task(self, - ws) -> Tuple[bool, Union[str, bytes]]: + async def _receive_batch_data_task( + self, + ws, + ) -> Tuple[bool, Union[str, bytes]]: """_summary_ Args: @@ -305,17 +366,17 @@ async def _receive_batch_data_task(self, await self._check_websocket_unexpected_message(msg) if msg.type == aiohttp.WSMsgType.TEXT: msg_json = msg.json() - logger.debug('Receive %s event' % msg_json[HEADER][EVENT_KEY]) + logger.debug("Receive %s event", msg_json[HEADER][EVENT_KEY]) if msg_json[HEADER][EVENT_KEY] == EventType.GENERATED: - payload = msg_json['payload'] + payload = msg_json["payload"] return False, payload elif msg_json[HEADER][EVENT_KEY] == EventType.FINISHED: - payload = msg_json['payload'] + payload = msg_json["payload"] return False, payload elif msg_json[HEADER][EVENT_KEY] == EventType.FAILED: self._on_failed(msg_json) else: - error = 'Receive unknown message: %s' % msg_json + error = f"Receive unknown message: {msg_json}" logger.error(error) raise UnknownMessageReceived(error) elif msg.type == aiohttp.WSMsgType.BINARY: @@ -327,36 +388,37 @@ async def _wait_for_task_started(self, ws): await self._check_websocket_unexpected_message(msg) if msg.type == aiohttp.WSMsgType.TEXT: msg_json = msg.json() - logger.debug('Receive %s event' % msg_json[HEADER][EVENT_KEY]) + logger.debug("Receive %s event", msg_json[HEADER][EVENT_KEY]) if msg_json[HEADER][EVENT_KEY] == EventType.STARTED: return elif msg_json[HEADER][EVENT_KEY] == EventType.FAILED: self._on_failed(msg_json) else: raise UnexpectedMessageReceived( - 'Receive unexpected message, expect task-started, real: %s.' 
# noqa E501 - % msg_json[HEADER][EVENT_KEY]) + "Receive unexpected message, expect task-started, " + f"real: {msg_json[HEADER][EVENT_KEY]}.", + ) elif msg.type == aiohttp.WSMsgType.BINARY: raise UnexpectedMessageReceived( - 'Receive unexpected binary message when wait for task-started' # noqa E501 + "Receive unexpected binary message when wait for task-started", # noqa E501 ) async def _check_websocket_unexpected_message(self, msg): if msg.type == aiohttp.WSMsgType.CLOSED: - details = 'WSMsgType.CLOSE, data: %s, extra: %s' % (msg.data, - msg.extra) - logger.error('Connection unexpected closed!') + details = f"WSMsgType.CLOSE, data: {msg.data}, extra: {msg.extra}" + logger.error("Connection unexpected closed!") raise UnexpectedMessageReceived( - 'Receive unexpected websocket close message, details: %s' % - details) - elif msg.type == aiohttp.WSMsgType.ERROR: - details = 'WSMsgType.ERROR, data: %s, extra: %s' % (msg.data, - msg.extra) - logger.error('Connection error: %s' % details) + f"Receive unexpected websocket close message, " + f"details: {details}", + ) + if msg.type == aiohttp.WSMsgType.ERROR: + details = f"WSMsgType.ERROR, data: {msg.data}, extra: {msg.extra}" + logger.error("Connection error: %s", details) raise UnexpectedMessageReceived( - 'Receive unexpected websocket error message details: %s.' % - details) + f"Receive unexpected websocket error message " + f"details: {details}.", + ) def _build_up_message(self, headers, payload): - message = {'header': headers, 'payload': payload} + message = {"header": headers, "payload": payload} return json.dumps(message) diff --git a/dashscope/app/__init__.py b/dashscope/app/__init__.py index e620cbf..e0f1ba7 100644 --- a/dashscope/app/__init__.py +++ b/dashscope/app/__init__.py @@ -1,5 +1,6 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
from .application import Application -__all__ = [Application] +__all__ = ["Application"] diff --git a/dashscope/app/application.py b/dashscope/app/application.py index 1f882a5..78a77a0 100644 --- a/dashscope/app/application.py +++ b/dashscope/app/application.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. """ @File : application.py @@ -12,37 +13,46 @@ from dashscope.app.application_response import ApplicationResponse from dashscope.client.base_api import BaseApi from dashscope.common.api_key import get_default_api_key -from dashscope.common.constants import (DEPRECATED_MESSAGE, HISTORY, MESSAGES, - PROMPT) +from dashscope.common.constants import ( + DEPRECATED_MESSAGE, + HISTORY, + MESSAGES, + PROMPT, +) from dashscope.common.error import InputRequired, InvalidInput from dashscope.common.logging import logger class Application(BaseApi): - task_group = 'apps' - function = 'completion' + task_group = "apps" + function = "completion" """API for app completion calls. 
""" + class DocReferenceType: - """ doc reference type for rag completion """ + """doc reference type for rag completion""" - simple = 'simple' + simple = "simple" - indexed = 'indexed' + indexed = "indexed" @classmethod - def _validate_params(cls, api_key, app_id): + def _validate_params( # pylint: disable=arguments-renamed + cls, + api_key, + app_id, + ): if api_key is None: api_key = get_default_api_key() if app_id is None or not app_id: - raise InputRequired('App id is required!') + raise InputRequired("App id is required!") return api_key, app_id @classmethod - def call( + def call( # type: ignore[override] cls, app_id: str, prompt: str = None, @@ -50,9 +60,15 @@ def call( workspace: str = None, api_key: str = None, messages: List[Message] = None, - **kwargs - ) -> Union[ApplicationResponse, Generator[ApplicationResponse, None, - None]]: + **kwargs, + ) -> Union[ + ApplicationResponse, + Generator[ + ApplicationResponse, + None, + None, + ], + ]: """Call app completion service. Args: @@ -70,7 +86,7 @@ def call( **kwargs: stream(bool, `optional`): Enable server-sent events - (ref: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events) # noqa E501 + (ref: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events) # noqa E501 # pylint: disable=line-too-long the result will back partially[qwen-turbo,bailian-v1]. temperature(float, `optional`): Used to control the degree of randomness and diversity. Specifically, the temperature @@ -88,27 +104,30 @@ def call( tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered[qwen-turbo,bailian-v1]. - top_k(int, `optional`): The size of the sample candidate set when generated. # noqa E501 - For example, when the value is 50, only the 50 highest-scoring tokens # noqa E501 + top_k(int, `optional`): The size of the sample candidate set when generated. 
# noqa E501 # pylint: disable=line-too-long + For example, when the value is 50, only the 50 highest-scoring tokens # noqa E501 # pylint: disable=line-too-long in a single generation form a randomly sampled candidate set. # noqa E501 The larger the value, the higher the randomness generated; # noqa E501 the smaller the value, the higher the certainty generated. # noqa E501 The default value is 0, which means the top_k policy is # noqa E501 not enabled. At this time, only the top_p policy takes effect. # noqa E501 - seed(int, `optional`): When generating, the seed of the random number is used to control the - randomness of the model generation. If you use the same seed, each run will generate the same results; - you can use the same seed when you need to reproduce the model's generated results. + seed( + int, + `optional` + ): When generating, the seed of the random number is used to control the + randomness of the model generation. If you use the same seed, each run will generate the same results; # pylint: disable=line-too-long + you can use the same seed when you need to reproduce the model's generated results. # pylint: disable=line-too-long The seed parameter supports unsigned 64-bit integer types. Default value 1234 session_id(str, `optional`): Session if for multiple rounds call. biz_params(dict, `optional`): The extra parameters for flow or plugin. - has_thoughts(bool, `optional`): Flag to return rag or plugin process details. Default value false. + has_thoughts(bool, `optional`): Flag to return rag or plugin process details. Default value false. # pylint: disable=line-too-long doc_tag_codes(list[str], `optional`): Tag code list for doc retrival. doc_reference_type(str, `optional`): The type of doc reference. - simple: simple format of doc retrival which not include index in response text but in doc reference list. + simple: simple format of doc retrival which not include index in response text but in doc reference list. 
# pylint: disable=line-too-long indexed: include both index in response text and doc reference list - memory_id(str, `optional`): Used to store long term context summary between end users and assistant. + memory_id(str, `optional`): Used to store long term context summary between end users and assistant. # pylint: disable=line-too-long image_list(list, `optional`): Used to pass image url list. - rag_options(dict, `optional`): Rag options for retrieval augmented generation options. + rag_options(dict, `optional`): Rag options for retrieval augmented generation options. # pylint: disable=line-too-long Raises: InvalidInput: The history and auto_history are mutually exclusive. @@ -120,45 +139,62 @@ def call( api_key, app_id = Application._validate_params(api_key, app_id) - if (prompt is None or not prompt) and (messages is None - or len(messages) == 0): - raise InputRequired('prompt or messages is required!') + if (prompt is None or not prompt) and ( + messages is None or len(messages) == 0 + ): + raise InputRequired("prompt or messages is required!") if workspace is not None and workspace: - headers = kwargs.pop('headers', {}) - headers['X-DashScope-WorkSpace'] = workspace - kwargs['headers'] = headers - - input, parameters = cls._build_input_parameters( - prompt, history, messages, **kwargs) - request = _build_api_request(model='', - input=input, - task_group=Application.task_group, - task=app_id, - function=Application.function, - workspace=workspace, - api_key=api_key, - is_service=False, - **parameters) + headers = kwargs.pop("headers", {}) + headers["X-DashScope-WorkSpace"] = workspace + kwargs["headers"] = headers + + ( + input, # pylint: disable=redefined-builtin + parameters, + ) = cls._build_input_parameters( + prompt, + history, + messages, + **kwargs, + ) + request = _build_api_request( + model="", + input=input, + task_group=Application.task_group, + task=app_id, + function=Application.function, + workspace=workspace, + api_key=api_key, + is_service=False, 
+ **parameters, + ) # call request service. response = request.call() - is_stream = kwargs.get('stream', False) + is_stream = kwargs.get("stream", False) if is_stream: - return (ApplicationResponse.from_api_response(rsp) - for rsp in response) + return ( + ApplicationResponse.from_api_response(rsp) for rsp in response + ) else: return ApplicationResponse.from_api_response(response) @classmethod - def _build_input_parameters(cls, prompt, history, messages, **kwargs): + def _build_input_parameters( # pylint: disable=too-many-branches + cls, + prompt, + history, + messages, + **kwargs, + ): parameters = {} input_param = {} if messages is not None: msgs = copy.deepcopy(messages) if prompt is not None and prompt: - msgs.append({'role': Role.USER, 'content': prompt}) + msgs.append({"role": Role.USER, "content": prompt}) input_param = {MESSAGES: msgs} elif history is not None and history: logger.warning(DEPRECATED_MESSAGE) @@ -168,36 +204,37 @@ def _build_input_parameters(cls, prompt, history, messages, **kwargs): else: input_param[PROMPT] = prompt - session_id = kwargs.pop('session_id', None) + session_id = kwargs.pop("session_id", None) if session_id is not None and session_id: - input_param['session_id'] = session_id + input_param["session_id"] = session_id - doc_reference_type = kwargs.pop('doc_reference_type', None) + doc_reference_type = kwargs.pop("doc_reference_type", None) if doc_reference_type is not None and doc_reference_type: - input_param['doc_reference_type'] = doc_reference_type + input_param["doc_reference_type"] = doc_reference_type - doc_tag_codes = kwargs.pop('doc_tag_codes', None) + doc_tag_codes = kwargs.pop("doc_tag_codes", None) if doc_tag_codes is not None: if isinstance(doc_tag_codes, list) and all( - isinstance(item, str) for item in doc_tag_codes): - input_param['doc_tag_codes'] = doc_tag_codes + isinstance(item, str) for item in doc_tag_codes + ): + input_param["doc_tag_codes"] = doc_tag_codes else: - raise InvalidInput('doc_tag_codes is not a 
List[str]') + raise InvalidInput("doc_tag_codes is not a List[str]") - memory_id = kwargs.pop('memory_id', None) + memory_id = kwargs.pop("memory_id", None) if memory_id is not None: - input_param['memory_id'] = memory_id + input_param["memory_id"] = memory_id - biz_params = kwargs.pop('biz_params', None) + biz_params = kwargs.pop("biz_params", None) if biz_params is not None and biz_params: - input_param['biz_params'] = biz_params + input_param["biz_params"] = biz_params - image_list = kwargs.pop('image_list', None) + image_list = kwargs.pop("image_list", None) if image_list is not None and image_list: - input_param['image_list'] = image_list + input_param["image_list"] = image_list - file_list = kwargs.pop('file_list', None) + file_list = kwargs.pop("file_list", None) if file_list is not None and file_list: - input_param['file_list'] = file_list + input_param["file_list"] = file_list return input_param, {**parameters, **kwargs} diff --git a/dashscope/app/application_response.py b/dashscope/app/application_response.py index a2bf32e..19af166 100644 --- a/dashscope/app/application_response.py +++ b/dashscope/app/application_response.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
""" @File : application_response.py @@ -6,10 +7,12 @@ """ from dataclasses import dataclass from http import HTTPStatus -from typing import Dict, List, Optional +from typing import Dict, List -from dashscope.api_entities.dashscope_response import (DashScopeAPIResponse, - DictMixin) +from dashscope.api_entities.dashscope_response import ( + DashScopeAPIResponse, + DictMixin, +) @dataclass(init=False) @@ -23,22 +26,24 @@ class ApplicationThought(DictMixin): action_input: Dict observation: str - def __init__(self, - thought: str = None, - action_type: str = None, - response: str = None, - action_name: str = None, - action: str = None, - action_input_stream: str = None, - action_input: Dict = None, - observation: str = None, - **kwargs): - """ Thought of app completion call result which describe model planning and doc retrieval + def __init__( + self, + thought: str = None, + action_type: str = None, + response: str = None, + action_name: str = None, + action: str = None, + action_input_stream: str = None, + action_input: Dict = None, + observation: str = None, + **kwargs, + ): + """Thought of app completion call result which describe model planning and doc retrieval # noqa: E501 # pylint: disable=line-too-long or plugin calls details. Args: thought (str, optional): Model's inference thought for doc retrieval or plugin process. - action_type (str, optional): Action type. response : final response; api: to run api calls. + action_type (str, optional): Action type. response : final response; api: to run api calls. # pylint: disable=line-too-long response (str, optional): Model's results. action_name (str, optional): Action name, e.g. searchDocument、api. action (str, optional): Code of action, means which plugin or action to be run. @@ -47,15 +52,17 @@ def __init__(self, observation (str, optional): Result of api call or doc retrieval. 
""" - super().__init__(thought=thought, - action_type=action_type, - response=response, - action_name=action_name, - action=action, - action_input_stream=action_input_stream, - action_input=action_input, - observation=observation, - **kwargs) + super().__init__( + thought=thought, + action_type=action_type, + response=response, + action_name=action_name, + action=action, + action_input_stream=action_input_stream, + action_input=action_input, + observation=observation, + **kwargs, + ) @dataclass(init=False) @@ -70,40 +77,45 @@ class ApplicationDocReference(DictMixin): images: List[str] page_number: List[int] - def __init__(self, - index_id: str = None, - title: str = None, - doc_id: str = None, - doc_name: str = None, - doc_url: str = None, - text: str = None, - biz_id: str = None, - images: List[str] = None, - page_number: List[int] = None, - **kwargs): - """ Doc references for retrieval result. + def __init__( + self, + index_id: str = None, + title: str = None, + doc_id: str = None, + doc_name: str = None, + doc_url: str = None, + text: str = None, + biz_id: str = None, + images: List[str] = None, + page_number: List[int] = None, + **kwargs, + ): + """Doc references for retrieval result. Args: - index_id (str, optional): Index id of doc retrival result reference. + index_id (str, optional): Index id of doc retrival result reference. # noqa: E501 title (str, optional): Title of original doc that retrieved. doc_id (str, optional): Id of original doc that retrieved. doc_name (str, optional): Name of original doc that retrieved. doc_url (str, optional): Url of original doc that retrieved. text (str, optional): Text in original doc that retrieved. - biz_id (str, optional): Biz id that caller is able to associated for biz logic. + biz_id (str, optional): Biz id that caller is able to associated for biz logic. 
# noqa: E501 # pylint: disable=line-too-long images (list, optional): List of referenced image URLs """ - super().__init__(index_id=index_id, - title=title, - doc_id=doc_id, - doc_name=doc_name, - doc_url=doc_url, - text=text, - biz_id=biz_id, - images=images, - page_number=page_number, - **kwargs) + super().__init__( + index_id=index_id, + title=title, + doc_id=doc_id, + doc_name=doc_name, + doc_url=doc_url, + text=text, + biz_id=biz_id, + images=images, + page_number=page_number, + **kwargs, + ) + @dataclass(init=False) class WorkflowMessage(DictMixin): @@ -119,16 +131,18 @@ class Message(DictMixin): role: str content: str - def __init__(self, - node_id: str = None, - node_name: str = None, - node_type: str = None, - node_status: str = None, - node_is_completed: str = None, - node_msg_seq_id: int = None, - message: Message = None, - **kwargs): - """ Workflow message. + def __init__( + self, + node_id: str = None, + node_name: str = None, + node_type: str = None, + node_status: str = None, + node_is_completed: str = None, + node_msg_seq_id: int = None, + message: Message = None, + **kwargs, + ): + """Workflow message. Args: node_id (str, optional): . @@ -140,14 +154,16 @@ def __init__(self, message (Message, optional): . 
""" - super().__init__(node_id=node_id, - node_name=node_name, - node_type=node_type, - node_status=node_status, - node_is_completed=node_is_completed, - node_msg_seq_id=node_msg_seq_id, - message=message, - **kwargs) + super().__init__( + node_id=node_id, + node_name=node_name, + node_type=node_type, + node_status=node_status, + node_is_completed=node_is_completed, + node_msg_seq_id=node_msg_seq_id, + message=message, + **kwargs, + ) @dataclass(init=False) @@ -159,15 +175,16 @@ class ApplicationOutput(DictMixin): doc_references: List[ApplicationDocReference] workflow_message: WorkflowMessage - def __init__(self, - text: str = None, - finish_reason: str = None, - session_id: str = None, - thoughts: List[ApplicationThought] = None, - doc_references: List[ApplicationDocReference] = None, - workflow_message: WorkflowMessage = None, - **kwargs): - + def __init__( + self, + text: str = None, + finish_reason: str = None, + session_id: str = None, + thoughts: List[ApplicationThought] = None, + doc_references: List[ApplicationDocReference] = None, + workflow_message: WorkflowMessage = None, + **kwargs, + ): ths = None if thoughts is not None: ths = [] @@ -180,13 +197,15 @@ def __init__(self, for ref in doc_references: refs.append(ApplicationDocReference(**ref)) - super().__init__(text=text, - finish_reason=finish_reason, - session_id=session_id, - thoughts=ths, - doc_references=refs, - workflow_message=workflow_message, - **kwargs) + super().__init__( + text=text, + finish_reason=finish_reason, + session_id=session_id, + thoughts=ths, + doc_references=refs, + workflow_message=workflow_message, + **kwargs, + ) @dataclass(init=False) @@ -195,15 +214,19 @@ class ApplicationModelUsage(DictMixin): input_tokens: int output_tokens: int - def __init__(self, - model_id: str = None, - input_tokens: int = 0, - output_tokens: int = 0, - **kwargs): - super().__init__(model_id=model_id, - input_tokens=input_tokens, - output_tokens=output_tokens, - **kwargs) + def __init__( + self, + 
model_id: str = None, + input_tokens: int = 0, + output_tokens: int = 0, + **kwargs, + ): + super().__init__( + model_id=model_id, + input_tokens=input_tokens, + output_tokens=output_tokens, + **kwargs, + ) @dataclass(init=False) @@ -238,9 +261,12 @@ def from_api_response(api_response: DashScopeAPIResponse): code=api_response.code, message=api_response.message, output=ApplicationOutput(**api_response.output), - usage=ApplicationUsage(**usage)) + usage=ApplicationUsage(**usage), + ) else: - return ApplicationResponse(status_code=api_response.status_code, - request_id=api_response.request_id, - code=api_response.code, - message=api_response.message) + return ApplicationResponse( + status_code=api_response.status_code, + request_id=api_response.request_id, + code=api_response.code, + message=api_response.message, + ) diff --git a/dashscope/assistants/__init__.py b/dashscope/assistants/__init__.py index d911fb2..61e91da 100644 --- a/dashscope/assistants/__init__.py +++ b/dashscope/assistants/__init__.py @@ -1,16 +1,19 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. # yapf: disable -from dashscope.assistants.assistant_types import (Assistant, AssistantFile, - AssistantList, - DeleteResponse) +from dashscope.assistants.assistant_types import ( + Assistant, AssistantFile, + AssistantList, + DeleteResponse, +) from dashscope.assistants.assistants import Assistants __all__ = [ - Assistant, - Assistants, - AssistantList, - AssistantFile, - DeleteResponse, + 'Assistant', + 'Assistants', + 'AssistantList', + 'AssistantFile', + 'DeleteResponse', ] diff --git a/dashscope/assistants/assistant_types.py b/dashscope/assistants/assistant_types.py index a862f93..049a124 100644 --- a/dashscope/assistants/assistant_types.py +++ b/dashscope/assistants/assistant_types.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
# adapter from openai sdk @@ -7,9 +8,16 @@ from dashscope.common.base_type import BaseList, BaseObjectMixin __all__ = [ - 'Assistant', 'AssistantFile', 'ToolCodeInterpreter', 'ToolSearch', - 'ToolWanX', 'FunctionDefinition', 'ToolFunction', 'AssistantFileList', - 'AssistantList', 'DeleteResponse' + "Assistant", + "AssistantFile", + "ToolCodeInterpreter", + "ToolSearch", + "ToolWanX", + "FunctionDefinition", + "ToolFunction", + "AssistantFileList", + "AssistantList", + "DeleteResponse", ] @@ -20,31 +28,31 @@ class AssistantFile(BaseObjectMixin): created_at: int object: str - def __init__(self, **kwargs): + def __init__(self, **kwargs): # pylint: disable=useless-parent-delegation super().__init__(**kwargs) @dataclass(init=False) class ToolCodeInterpreter(BaseObjectMixin): - type: str = 'code_interpreter' + type: str = "code_interpreter" - def __init__(self, **kwargs): + def __init__(self, **kwargs): # pylint: disable=useless-parent-delegation super().__init__(**kwargs) @dataclass(init=False) class ToolSearch(BaseObjectMixin): - type: str = 'search' + type: str = "search" - def __init__(self, **kwargs): + def __init__(self, **kwargs): # pylint: disable=useless-parent-delegation super().__init__(**kwargs) @dataclass(init=False) class ToolWanX(BaseObjectMixin): - type: str = 'wanx' + type: str = "wanx" - def __init__(self, **kwargs): + def __init__(self, **kwargs): # pylint: disable=useless-parent-delegation super().__init__(**kwargs) @@ -54,34 +62,34 @@ class FunctionDefinition(BaseObjectMixin): description: Optional[str] = None parameters: Optional[Dict[str, object]] = None - def __init__(self, **kwargs): + def __init__(self, **kwargs): # pylint: disable=useless-parent-delegation super().__init__(**kwargs) @dataclass(init=False) class ToolFunction(BaseObjectMixin): function: FunctionDefinition - type: str = 'function' + type: str = "function" def __init__(self, **kwargs): - self.function = FunctionDefinition(**kwargs.pop('function', {})) + self.function = 
FunctionDefinition(**kwargs.pop("function", {})) super().__init__(**kwargs) Tool = Union[ToolCodeInterpreter, ToolSearch, ToolFunction, ToolWanX] ASSISTANT_SUPPORT_TOOL = { - 'code_interpreter': ToolCodeInterpreter, - 'search': ToolSearch, - 'wanx': ToolWanX, - 'function': ToolFunction + "code_interpreter": ToolCodeInterpreter, + "search": ToolSearch, + "wanx": ToolWanX, + "function": ToolFunction, } def convert_tools_dict_to_objects(tools): tools_object = [] for tool in tools: - if 'type' in tool: - tool_type = ASSISTANT_SUPPORT_TOOL.get(tool['type'], None) + if "type" in tool: + tool_type = ASSISTANT_SUPPORT_TOOL.get(tool["type"], None) if tool_type: tools_object.append(tool_type(**tool)) else: @@ -107,16 +115,16 @@ class Assistant(BaseObjectMixin): """ model: str name: Optional[str] = None - created_at: int + created_at: int # type: ignore[misc] """The Unix timestamp (in seconds) for when the assistant was created. """ description: Optional[str] = None - file_ids: List[str] + file_ids: List[str] # type: ignore[misc] instructions: Optional[str] = None metadata: Optional[object] = None - tools: List[Tool] + tools: List[Tool] # type: ignore[misc] object: Optional[str] = None @@ -128,7 +136,7 @@ class Assistant(BaseObjectMixin): request_id: Optional[str] = None def __init__(self, **kwargs): - self.tools = convert_tools_dict_to_objects(kwargs.pop('tools', [])) + self.tools = convert_tools_dict_to_objects(kwargs.pop("tools", [])) super().__init__(**kwargs) @@ -136,34 +144,44 @@ def __init__(self, **kwargs): class AssistantList(BaseList): data: List[Assistant] - def __init__(self, - has_more: bool = None, - last_id: Optional[str] = None, - first_id: Optional[str] = None, - data: List[Assistant] = [], - **kwargs): - super().__init__(has_more=has_more, - last_id=last_id, - first_id=first_id, - data=data, - **kwargs) + # pylint: disable=dangerous-default-value + def __init__( + self, + has_more: bool = None, + last_id: Optional[str] = None, + first_id: Optional[str] = 
None, + data: List[Assistant] = [], + **kwargs, + ): + super().__init__( + has_more=has_more, + last_id=last_id, + first_id=first_id, + data=data, + **kwargs, + ) @dataclass(init=False) class AssistantFileList(BaseList): data: List[AssistantFile] - def __init__(self, - has_more: bool = None, - last_id: Optional[str] = None, - first_id: Optional[str] = None, - data: List[AssistantFile] = [], - **kwargs): - super().__init__(has_more=has_more, - last_id=last_id, - first_id=first_id, - data=data, - **kwargs) + # pylint: disable=dangerous-default-value + def __init__( + self, + has_more: bool = None, + last_id: Optional[str] = None, + first_id: Optional[str] = None, + data: List[AssistantFile] = [], + **kwargs, + ): + super().__init__( + has_more=has_more, + last_id=last_id, + first_id=first_id, + data=data, + **kwargs, + ) @dataclass(init=False) @@ -171,5 +189,5 @@ class DeleteResponse(BaseObjectMixin): id: str deleted: bool - def __init__(self, **kwargs): + def __init__(self, **kwargs): # pylint: disable=useless-parent-delegation super().__init__(**kwargs) diff --git a/dashscope/assistants/assistants.py b/dashscope/assistants/assistants.py index 56d57a5..7dd623d 100644 --- a/dashscope/assistants/assistants.py +++ b/dashscope/assistants/assistants.py @@ -1,22 +1,38 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
from typing import Dict, List, Optional -from dashscope.assistants.assistant_types import (Assistant, AssistantList, - DeleteResponse) -from dashscope.client.base_api import (CancelMixin, CreateMixin, DeleteMixin, - GetStatusMixin, ListObjectMixin, - UpdateMixin) +from dashscope.assistants.assistant_types import ( + Assistant, + AssistantList, + DeleteResponse, +) +from dashscope.client.base_api import ( + CancelMixin, + CreateMixin, + DeleteMixin, + GetStatusMixin, + ListObjectMixin, + UpdateMixin, +) from dashscope.common.error import ModelRequired -__all__ = ['Assistants'] +__all__ = ["Assistants"] -class Assistants(CreateMixin, CancelMixin, DeleteMixin, ListObjectMixin, - GetStatusMixin, UpdateMixin): - SUB_PATH = 'assistants' +class Assistants( + CreateMixin, + CancelMixin, + DeleteMixin, + ListObjectMixin, + GetStatusMixin, + UpdateMixin, +): + SUB_PATH = "assistants" @classmethod + # pylint: disable=dangerous-default-value def _create_assistant_object( cls, model: str = None, @@ -33,53 +49,63 @@ def _create_assistant_object( ): obj = {} if model: - obj['model'] = model + obj["model"] = model if name: - obj['name'] = name + obj["name"] = name if description: - obj['description'] = description + obj["description"] = description if instructions: - obj['instructions'] = instructions + obj["instructions"] = instructions if tools is not None: - obj['tools'] = tools - obj['file_ids'] = file_ids - obj['metadata'] = metadata + obj["tools"] = tools + obj["file_ids"] = file_ids + obj["metadata"] = metadata if top_p is not None: - obj['top_p'] = top_p + obj["top_p"] = top_p if top_k is not None: - obj['top_k'] = top_k + obj["top_k"] = top_k if temperature is not None: - obj['temperature'] = temperature + obj["temperature"] = temperature if max_tokens is not None: - obj['max_tokens'] = max_tokens + obj["max_tokens"] = max_tokens return obj @classmethod - def call(cls, - *, - model: str, - name: str = None, - description: str = None, - instructions: str = None, - tools: 
Optional[List[Dict]] = None, - file_ids: Optional[List[str]] = [], - metadata: Dict = None, - workspace: str = None, - api_key: str = None, - **kwargs) -> Assistant: + # pylint: disable=dangerous-default-value + def call( # type: ignore[override] + cls, + *, + model: str, + name: str = None, + description: str = None, + instructions: str = None, + tools: Optional[List[Dict]] = None, + file_ids: Optional[List[str]] = [], + metadata: Dict = None, + workspace: str = None, + api_key: str = None, + **kwargs, + ) -> Assistant: """Create Assistant. Args: model (str): The model to use. name (str, optional): The assistant name. Defaults to None. - description (str, optional): The assistant description. Defaults to None. - instructions (str, optional): The system instructions this assistant uses. Defaults to None. - tools (Optional[List[Dict]], optional): List of tools to use. Defaults to []. - file_ids (Optional[List[str]], optional): : The files to use. Defaults to []. - metadata (Dict, optional): Custom key-value pairs associate with assistant. Defaults to None. - workspace (str, optional): The DashScope workspace id. Defaults to None. + description (str, optional): + The assistant description. Defaults to None. + instructions (str, optional): + The system instructions this assistant uses. Defaults to None. + tools (Optional[List[Dict]], optional): + List of tools to use. Defaults to []. + file_ids (Optional[List[str]], optional): + : The files to use. Defaults to []. + metadata (Dict, optional): + Custom key-value pairs associate with assistant. Defaults to + None. + workspace (str, optional): + The DashScope workspace id. Defaults to None. api_key (str, optional): The DashScope api key. Defaults to None. Raises: @@ -88,50 +114,65 @@ def call(cls, Returns: Assistant: The `Assistant` object. 
""" - return cls.create(model=model, - name=name, - description=description, - instructions=instructions, - tools=tools, - file_ids=file_ids, - metadata=metadata, - workspace=workspace, - api_key=api_key, - **kwargs) + return cls.create( + model=model, + name=name, + description=description, + instructions=instructions, + tools=tools, + file_ids=file_ids, + metadata=metadata, + workspace=workspace, + api_key=api_key, + **kwargs, + ) @classmethod - def create(cls, - *, - model: str, - name: str = None, - description: str = None, - instructions: str = None, - tools: Optional[List[Dict]] = None, - file_ids: Optional[List[str]] = [], - metadata: Dict = None, - workspace: str = None, - api_key: str = None, - top_p: Optional[float] = None, - top_k: Optional[int] = None, - temperature: Optional[float] = None, - max_tokens: Optional[int] = None, - **kwargs) -> Assistant: + # pylint: disable=dangerous-default-value + def create( + cls, + *, + model: str, + name: str = None, + description: str = None, + instructions: str = None, + tools: Optional[List[Dict]] = None, + file_ids: Optional[List[str]] = [], + metadata: Dict = None, + workspace: str = None, + api_key: str = None, + top_p: Optional[float] = None, + top_k: Optional[int] = None, + temperature: Optional[float] = None, + max_tokens: Optional[int] = None, + **kwargs, + ) -> Assistant: """Create Assistant. Args: model (str): The model to use. name (str, optional): The assistant name. Defaults to None. - description (str, optional): The assistant description. Defaults to None. - instructions (str, optional): The system instructions this assistant uses. Defaults to None. - tools (Optional[List[Dict]], optional): List of tools to use. Defaults to []. - file_ids (Optional[List[str]], optional): : The files to use. Defaults to []. - metadata (Dict, optional): Custom key-value pairs associate with assistant. Defaults to None. - workspace (str, optional): The DashScope workspace id. Defaults to None. 
+ description (str, optional): + The assistant description. Defaults to None. + instructions (str, optional): + The system instructions this assistant uses. Defaults to None. + tools (Optional[List[Dict]], optional): + List of tools to use. Defaults to []. + file_ids (Optional[List[str]], optional): + : The files to use. Defaults to []. + metadata (Dict, optional): + Custom key-value pairs associate with assistant. Defaults to + None. + workspace (str, optional): + The DashScope workspace id. Defaults to None. api_key (str, optional): The DashScope api key. Defaults to None. - top_p (float, optional): top_p parameter for model. Defaults to None. + top_p (float, optional): + top_p parameter for model. Defaults to None. top_k (int, optional): top_p parameter for model. Defaults to None. - temperature (float, optional): temperature parameter for model. Defaults to None. - max_tokens (int, optional): max_tokens parameter for model. Defaults to None. + temperature (float, optional): + temperature parameter for model. Defaults to None. + max_tokens (int, optional): + max_tokens parameter for model. Defaults to None. Raises: ModelRequired: The model is required. @@ -140,24 +181,38 @@ def create(cls, Assistant: The `Assistant` object. 
""" if not model: - raise ModelRequired('Model is required!') - data = cls._create_assistant_object(model, name, description, - instructions, tools, file_ids, - metadata, top_p, top_k, temperature, max_tokens) - response = super().call(data=data, - api_key=api_key, - flattened_output=True, - workspace=workspace, - **kwargs) + raise ModelRequired("Model is required!") + data = cls._create_assistant_object( + model, + name, + description, + instructions, + tools, + file_ids, + metadata, + top_p, + top_k, + temperature, + max_tokens, + ) + response = super().call( + data=data, + api_key=api_key, + flattened_output=True, + workspace=workspace, + **kwargs, + ) return Assistant(**response) @classmethod - def retrieve(cls, - assistant_id: str, - *, - workspace: str = None, - api_key: str = None, - **kwargs) -> Assistant: + def retrieve( + cls, + assistant_id: str, + *, + workspace: str = None, + api_key: str = None, + **kwargs, + ) -> Assistant: """Get the `Assistant`. Args: @@ -168,18 +223,22 @@ def retrieve(cls, Returns: Assistant: The `Assistant` object. """ - return cls.get(assistant_id, - workspace=workspace, - api_key=api_key, - **kwargs) + return cls.get( + assistant_id, + workspace=workspace, + api_key=api_key, + **kwargs, + ) @classmethod - def get(cls, - assistant_id: str, - *, - workspace: str = None, - api_key: str = None, - **kwargs) -> Assistant: + def get( # type: ignore[override] + cls, + assistant_id: str, + *, + workspace: str = None, + api_key: str = None, + **kwargs, + ) -> Assistant: """Get the `Assistant`. Args: @@ -191,106 +250,138 @@ def get(cls, Assistant: The `Assistant` object. 
""" if not assistant_id: - raise ModelRequired('assistant_id is required!') - response = super().get(assistant_id, - workspace=workspace, - api_key=api_key, - flattened_output=True, - **kwargs) + raise ModelRequired("assistant_id is required!") + response = super().get( + assistant_id, + workspace=workspace, + api_key=api_key, + flattened_output=True, + **kwargs, + ) return Assistant(**response) @classmethod - def list(cls, - *, - limit: int = None, - order: str = None, - after: str = None, - before: str = None, - workspace: str = None, - api_key: str = None, - **kwargs) -> AssistantList: + def list( # type: ignore[override] + cls, + *, + limit: int = None, + order: str = None, + after: str = None, + before: str = None, + workspace: str = None, + api_key: str = None, + **kwargs, + ) -> AssistantList: """List assistants Args: - limit (int, optional): How many assistant to retrieve. Defaults to None. + limit (int, optional): + How many assistant to retrieve. Defaults to None. order (str, optional): Sort order by created_at. Defaults to None. after (str, optional): Assistant id after. Defaults to None. before (str, optional): Assistant id before. Defaults to None. - workspace (str, optional): The DashScope workspace id. Defaults to None. + workspace (str, optional): + The DashScope workspace id. Defaults to None. api_key (str, optional): Your DashScope api key. Defaults to None. Returns: AssistantList: The list of assistants. 
""" - response = super().list(limit=limit, - order=order, - after=after, - before=before, - workspace=workspace, - api_key=api_key, - flattened_output=True, - **kwargs) + response = super().list( + limit=limit, + order=order, + after=after, + before=before, + workspace=workspace, + api_key=api_key, + flattened_output=True, + **kwargs, + ) return AssistantList(**response) @classmethod - def update(cls, - assistant_id: str, - *, - model: str = None, - name: str = None, - description: str = None, - instructions: str = None, - tools: Optional[List[Dict]] = None, - file_ids: Optional[List[str]] = [], - metadata: Dict = None, - workspace: str = None, - api_key: str = None, - top_p: Optional[float] = None, - top_k: Optional[int] = None, - temperature: Optional[float] = None, - max_tokens: Optional[int] = None, - **kwargs) -> Assistant: + # pylint: disable=dangerous-default-value + def update( # type: ignore[override] + cls, + assistant_id: str, + *, + model: str = None, + name: str = None, + description: str = None, + instructions: str = None, + tools: Optional[List[Dict]] = None, + file_ids: Optional[List[str]] = [], + metadata: Dict = None, + workspace: str = None, + api_key: str = None, + top_p: Optional[float] = None, + top_k: Optional[int] = None, + temperature: Optional[float] = None, + max_tokens: Optional[int] = None, + **kwargs, + ) -> Assistant: """Update an exist assistants Args: assistant_id (str): The target assistant id. model (str): The model to use. name (str, optional): The assistant name. Defaults to None. - description (str, optional): The assistant description . Defaults to None. - instructions (str, optional): The system instructions this assistant uses.. Defaults to None. - tools (Optional[str], optional): List of tools to use.. Defaults to []. - file_ids (Optional[str], optional): The files to use in assistants.. Defaults to []. - metadata (Dict, optional): Custom key-value pairs associate with assistant. Defaults to None. 
+ description (str, optional): + The assistant description . Defaults to None. + instructions (str, optional): The system instructions this assistant uses.. Defaults to None. # noqa: E501 # pylint: disable=line-too-long + tools (Optional[str], optional): List of tools to use.. Defaults to []. # noqa: E501 + file_ids (Optional[str], optional): The files to use in assistants.. Defaults to []. # noqa: E501 # pylint: disable=line-too-long + metadata (Dict, optional): + Custom key-value pairs associate with assistant. Defaults to + None. workspace (str): The DashScope workspace id. - api_key (str, optional): The DashScope workspace id. Defaults to None. - top_p (float, optional): top_p parameter for model. Defaults to None. + api_key (str, optional): + The DashScope workspace id. Defaults to None. + top_p (float, optional): + top_p parameter for model. Defaults to None. top_k (int, optional): top_p parameter for model. Defaults to None. - temperature (float, optional): temperature parameter for model. Defaults to None. - max_tokens (int, optional): max_tokens parameter for model. Defaults to None. + temperature (float, optional): + temperature parameter for model. Defaults to None. + max_tokens (int, optional): + max_tokens parameter for model. Defaults to None. Returns: Assistant: The updated assistant. 
""" if not assistant_id: - raise ModelRequired('assistant_id is required!') - response = super().update(assistant_id, - cls._create_assistant_object( - model, name, description, instructions, - tools, file_ids, metadata, top_p, top_k, temperature, max_tokens), - api_key=api_key, - workspace=workspace, - flattened_output=True, - method='post', - **kwargs) + raise ModelRequired("assistant_id is required!") + response = super().update( + assistant_id, + cls._create_assistant_object( + model, + name, + description, + instructions, + tools, + file_ids, + metadata, + top_p, + top_k, + temperature, + max_tokens, + ), + api_key=api_key, + workspace=workspace, + flattened_output=True, + method="post", + **kwargs, + ) return Assistant(**response) @classmethod - def delete(cls, - assistant_id: str, - *, - workspace: str = None, - api_key: str = None, - **kwargs) -> DeleteResponse: + def delete( # type: ignore[override] + cls, + assistant_id: str, + *, + workspace: str = None, + api_key: str = None, + **kwargs, + ) -> DeleteResponse: """Delete uploaded file. Args: @@ -302,10 +393,12 @@ def delete(cls, AssistantsDeleteResponse: Delete result. """ if not assistant_id: - raise ModelRequired('assistant_id is required!') - response = super().delete(assistant_id, - api_key=api_key, - workspace=workspace, - flattened_output=True, - **kwargs) + raise ModelRequired("assistant_id is required!") + response = super().delete( + assistant_id, + api_key=api_key, + workspace=workspace, + flattened_output=True, + **kwargs, + ) return DeleteResponse(**response) diff --git a/dashscope/assistants/files.py b/dashscope/assistants/files.py index c2faa38..451c671 100644 --- a/dashscope/assistants/files.py +++ b/dashscope/assistants/files.py @@ -1,33 +1,43 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
from typing import Optional -from dashscope.assistants.assistant_types import (AssistantFile, - AssistantFileList, - DeleteResponse) -from dashscope.client.base_api import (CreateMixin, DeleteMixin, - GetStatusMixin, ListObjectMixin) +from dashscope.assistants.assistant_types import ( + AssistantFile, + AssistantFileList, + DeleteResponse, +) +from dashscope.client.base_api import ( + CreateMixin, + DeleteMixin, + GetStatusMixin, + ListObjectMixin, +) from dashscope.common.error import InputRequired -__all__ = ['Files'] +__all__ = ["Files"] class Files(CreateMixin, DeleteMixin, ListObjectMixin, GetStatusMixin): - SUB_PATH = 'assistants' + SUB_PATH = "assistants" @classmethod - def call(cls, - assistant_id: str, - *, - file_id: str, - workspace: str = None, - api_key: str = None, - **kwargs) -> AssistantFile: + def call( # type: ignore[override] + cls, + assistant_id: str, + *, + file_id: str, + workspace: str = None, + api_key: str = None, + **kwargs, + ) -> AssistantFile: """Create assistant file. Args: assistant_id (str): The target assistant id. file_id (str): The file id. - workspace (str, optional): The DashScope workspace id. Defaults to None. + workspace (str, optional): + The DashScope workspace id. Defaults to None. api_key (str, optional): The DashScope api key. Defaults to None. Raises: @@ -36,26 +46,31 @@ def call(cls, Returns: AssistantFile: The assistant file object. """ - return cls.create(assistant_id, - file_id=file_id, - workspace=workspace, - api_key=api_key, - **kwargs) + return cls.create( + assistant_id, + file_id=file_id, + workspace=workspace, + api_key=api_key, + **kwargs, + ) @classmethod - def create(cls, - assistant_id: str, - *, - file_id: str, - workspace: str = None, - api_key: str = None, - **kwargs) -> AssistantFile: + def create( + cls, + assistant_id: str, + *, + file_id: str, + workspace: str = None, + api_key: str = None, + **kwargs, + ) -> AssistantFile: """Create assistant file. 
Args: assistant_id (str): The target assistant id. file_id (str): The file id. - workspace (str, optional): The DashScope workspace id. Defaults to None. + workspace (str, optional): + The DashScope workspace id. Defaults to None. api_key (str, optional): The DashScope api key. Defaults to None. Raises: @@ -65,122 +80,145 @@ def create(cls, AssistantFile: _description_ """ if not file_id or not assistant_id: - raise InputRequired('input file_id and assistant_id is required!') - - response = super().call(data={'file_id': file_id}, - path=f'assistants/{assistant_id}/files', - api_key=api_key, - flattened_output=True, - workspace=workspace, - **kwargs) + raise InputRequired("input file_id and assistant_id is required!") + + response = super().call( + data={"file_id": file_id}, + path=f"assistants/{assistant_id}/files", + api_key=api_key, + flattened_output=True, + workspace=workspace, + **kwargs, + ) return AssistantFile(**response) @classmethod - def list(cls, - assistant_id: str, - *, - limit: int = None, - order: str = None, - after: str = None, - before: str = None, - workspace: str = None, - api_key: str = None, - **kwargs) -> AssistantFileList: + def list( # type: ignore[override] + cls, + assistant_id: str, + *, + limit: int = None, + order: str = None, + after: str = None, + before: str = None, + workspace: str = None, + api_key: str = None, + **kwargs, + ) -> AssistantFileList: """List assistant files. Args: assistant_id (str): The assistant id. - limit (int, optional): How many assistant to retrieve. Defaults to None. + limit (int, optional): + How many assistant to retrieve. Defaults to None. order (str, optional): Sort order by created_at. Defaults to None. after (str, optional): Assistant id after. Defaults to None. before (str, optional): Assistant id before. Defaults to None. - workspace (str, optional): The DashScope workspace id. Defaults to None. + workspace (str, optional): + The DashScope workspace id. Defaults to None. 
api_key (str, optional): Your DashScope api key. Defaults to None. Returns: ListAssistantFile: The list of file objects. """ - response = super().list(limit=limit, - order=order, - after=after, - before=before, - path=f'assistants/{assistant_id}/files', - api_key=api_key, - flattened_output=True, - workspace=workspace, - **kwargs) + response = super().list( + limit=limit, + order=order, + after=after, + before=before, + path=f"assistants/{assistant_id}/files", + api_key=api_key, + flattened_output=True, + workspace=workspace, + **kwargs, + ) return AssistantFileList(**response) @classmethod - def retrieve(cls, - file_id: str, - *, - assistant_id: str, - workspace: str = None, - api_key: str = None, - **kwargs) -> AssistantFile: + def retrieve( + cls, + file_id: str, + *, + assistant_id: str, + workspace: str = None, + api_key: str = None, + **kwargs, + ) -> AssistantFile: """Retrieve file information. Args: file_id (str): The file if. assistant_id (str): The assistant id of the file. - workspace (str, optional): The DashScope workspace id. Defaults to None. + workspace (str, optional): + The DashScope workspace id. Defaults to None. api_key (str, optional): Your DashScope api key. Defaults to None. Returns: AssistantFile: The `AssistantFile` object. 
""" if not assistant_id or not file_id: - raise InputRequired('assistant id and file id are required!') + raise InputRequired("assistant id and file id are required!") response = super().get( file_id, - path=f'assistants/{assistant_id}/files/{file_id}', + path=f"assistants/{assistant_id}/files/{file_id}", api_key=api_key, flattened_output=True, workspace=workspace, - **kwargs) + **kwargs, + ) return AssistantFile(**response) @classmethod - def get(cls, - file_id: str, - *, - assistant_id: str, - workspace: str = None, - api_key: str = None, - **kwargs) -> Optional[AssistantFile]: + def get( # type: ignore[override] + cls, + file_id: str, + *, + assistant_id: str, + workspace: str = None, + api_key: str = None, + **kwargs, + ) -> Optional[AssistantFile]: """Retrieve file information. Args: file_id (str): The file if. assistant_id (str): The assistant id of the file. - workspace (str, optional): The DashScope workspace id. Defaults to None. + workspace (str, optional): + The DashScope workspace id. Defaults to None. api_key (str, optional): Your DashScope api key. Defaults to None. Returns: AssistantFile: The `AssistantFile` object. """ - response = super().get(target=assistant_id + '/files/' + file_id, api_key=api_key, workspace=workspace, **kwargs) + response = super().get( + target=assistant_id + "/files/" + file_id, + api_key=api_key, + workspace=workspace, + **kwargs, + ) if response.status_code == 200 and response.output: return AssistantFile(**response.output) else: return None @classmethod - def delete(cls, - file_id: str, - *, - assistant_id: str, - workspace: str = None, - api_key: str = None, - **kwargs) -> DeleteResponse: + def delete( # type: ignore[override] + cls, + file_id: str, + *, + assistant_id: str, + workspace: str = None, + api_key: str = None, + **kwargs, + ) -> DeleteResponse: """Delete the `file_id`. Args: file_id (str): The file to be deleted. assistant_id (str): The assistant id of the file. 
- workspace (str, optional): The DashScope workspace id. Defaults to None. + workspace (str, optional): + The DashScope workspace id. Defaults to None. api_key (str, optional): Your DashScope api key. Defaults to None. Returns: @@ -189,9 +227,10 @@ def delete(cls, response = super().delete( file_id, - path=f'assistants/{assistant_id}/files/{file_id}', + path=f"assistants/{assistant_id}/files/{file_id}", api_key=api_key, flattened_output=True, workspace=workspace, - **kwargs) + **kwargs, + ) return DeleteResponse(**response) diff --git a/dashscope/audio/__init__.py b/dashscope/audio/__init__.py index 7ac13e5..bdf170a 100644 --- a/dashscope/audio/__init__.py +++ b/dashscope/audio/__init__.py @@ -1,5 +1,13 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. from . import asr, tts, tts_v2, qwen_tts, qwen_tts_realtime, qwen_omni -__all__ = [asr, tts, tts_v2, qwen_tts, qwen_tts_realtime, qwen_omni] +__all__ = [ # type: ignore[misc] + asr, + tts, + tts_v2, + qwen_tts, + qwen_tts_realtime, + qwen_omni, +] # noqa: E501 diff --git a/dashscope/audio/asr/__init__.py b/dashscope/audio/asr/__init__.py index 4b5c4eb..3f68871 100644 --- a/dashscope/audio/asr/__init__.py +++ b/dashscope/audio/asr/__init__.py @@ -1,20 +1,33 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
from .asr_phrase_manager import AsrPhraseManager from .recognition import Recognition, RecognitionCallback, RecognitionResult from .transcription import Transcription -from .translation_recognizer import (TranscriptionResult, Translation, - TranslationRecognizerCallback, - TranslationRecognizerChat, - TranslationRecognizerRealtime, - TranslationRecognizerResultPack, - TranslationResult) +from .translation_recognizer import ( + TranscriptionResult, + Translation, + TranslationRecognizerCallback, + TranslationRecognizerChat, + TranslationRecognizerRealtime, + TranslationRecognizerResultPack, + TranslationResult, +) from .vocabulary import VocabularyService, VocabularyServiceException __all__ = [ - 'Transcription', 'Recognition', 'RecognitionCallback', 'RecognitionResult', - 'AsrPhraseManager', 'VocabularyServiceException', 'VocabularyService', - 'TranslationRecognizerRealtime', 'TranslationRecognizerChat', - 'TranslationRecognizerCallback', 'Translation', 'TranslationResult', - 'TranscriptionResult', 'TranslationRecognizerResultPack' + "Transcription", + "Recognition", + "RecognitionCallback", + "RecognitionResult", + "AsrPhraseManager", + "VocabularyServiceException", + "VocabularyService", + "TranslationRecognizerRealtime", + "TranslationRecognizerChat", + "TranslationRecognizerCallback", + "Translation", + "TranslationResult", + "TranscriptionResult", + "TranslationRecognizerResultPack", ] diff --git a/dashscope/audio/asr/asr_phrase_manager.py b/dashscope/audio/asr/asr_phrase_manager.py index 9a76248..4d51b18 100644 --- a/dashscope/audio/asr/asr_phrase_manager.py +++ b/dashscope/audio/asr/asr_phrase_manager.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. from http import HTTPStatus @@ -11,15 +12,17 @@ class AsrPhraseManager(BaseAsyncApi): - """Hot word management for speech recognition. 
- """ + """Hot word management for speech recognition.""" + @classmethod - def create_phrases(cls, - model: str, - phrases: Dict[str, Any], - training_type: str = 'compile_asr_phrase', - workspace: str = None, - **kwargs) -> DashScopeAPIResponse: + def create_phrases( + cls, + model: str, + phrases: Dict[str, Any], + training_type: str = "compile_asr_phrase", + workspace: str = None, + **kwargs, + ) -> DashScopeAPIResponse: """Create hot words. Args: @@ -37,34 +40,38 @@ def create_phrases(cls, DashScopeAPIResponse: The results of creating hot words. """ if phrases is None or len(phrases) == 0: - raise InvalidParameter('phrases is empty!') + raise InvalidParameter("phrases is empty!") if training_type is None or len(training_type) == 0: - raise InvalidParameter('training_type is empty!') + raise InvalidParameter("training_type is empty!") original_ft_sub_path = FineTunes.SUB_PATH - FineTunes.SUB_PATH = 'fine-tunes' - response = FineTunes.call(model=model, - training_file_ids=[], - validation_file_ids=[], - mode=training_type, - hyper_parameters={'phrase_list': phrases}, - workspace=workspace, - **kwargs) + FineTunes.SUB_PATH = "fine-tunes" + response = FineTunes.call( + model=model, + training_file_ids=[], + validation_file_ids=[], + mode=training_type, + hyper_parameters={"phrase_list": phrases}, + workspace=workspace, + **kwargs, + ) FineTunes.SUB_PATH = original_ft_sub_path if response.status_code != HTTPStatus.OK: - logger.error('Create phrase failed, ' + str(response)) + logger.error("Create phrase failed, %s", response) - return response + return response # type: ignore[return-value] @classmethod - def update_phrases(cls, - model: str, - phrase_id: str, - phrases: Dict[str, Any], - training_type: str = 'compile_asr_phrase', - workspace: str = None, - **kwargs) -> DashScopeAPIResponse: + def update_phrases( + cls, + model: str, + phrase_id: str, + phrases: Dict[str, Any], + training_type: str = "compile_asr_phrase", + workspace: str = None, + **kwargs, + ) -> 
DashScopeAPIResponse: """Update the hot words marked phrase_id. Args: @@ -84,34 +91,38 @@ def update_phrases(cls, DashScopeAPIResponse: The results of updating hot words. """ if phrase_id is None or len(phrase_id) == 0: - raise InvalidParameter('phrase_id is empty!') + raise InvalidParameter("phrase_id is empty!") if phrases is None or len(phrases) == 0: - raise InvalidParameter('phrases is empty!') + raise InvalidParameter("phrases is empty!") if training_type is None or len(training_type) == 0: - raise InvalidParameter('training_type is empty!') + raise InvalidParameter("training_type is empty!") original_ft_sub_path = FineTunes.SUB_PATH - FineTunes.SUB_PATH = 'fine-tunes' - response = FineTunes.call(model=model, - training_file_ids=[], - validation_file_ids=[], - mode=training_type, - hyper_parameters={'phrase_list': phrases}, - finetuned_output=phrase_id, - workspace=workspace, - **kwargs) + FineTunes.SUB_PATH = "fine-tunes" + response = FineTunes.call( + model=model, + training_file_ids=[], + validation_file_ids=[], + mode=training_type, + hyper_parameters={"phrase_list": phrases}, + finetuned_output=phrase_id, + workspace=workspace, + **kwargs, + ) FineTunes.SUB_PATH = original_ft_sub_path if response.status_code != HTTPStatus.OK: - logger.error('Update phrase failed, ' + str(response)) + logger.error("Update phrase failed, %s", response) - return response + return response # type: ignore[return-value] @classmethod - def query_phrases(cls, - phrase_id: str, - workspace: str = None, - **kwargs) -> DashScopeAPIResponse: + def query_phrases( + cls, + phrase_id: str, + workspace: str = None, + **kwargs, + ) -> DashScopeAPIResponse: """Query the hot words by phrase_id. Args: @@ -126,26 +137,30 @@ def query_phrases(cls, AsrPhraseManagerResult: The results of querying hot words. 
""" if phrase_id is None or len(phrase_id) == 0: - raise InvalidParameter('phrase_id is empty!') + raise InvalidParameter("phrase_id is empty!") original_ft_sub_path = FineTunes.SUB_PATH - FineTunes.SUB_PATH = 'fine-tunes/outputs' - response = FineTunes.get(job_id=phrase_id, - workspace=workspace, - **kwargs) + FineTunes.SUB_PATH = "fine-tunes/outputs" + response = FineTunes.get( + job_id=phrase_id, + workspace=workspace, + **kwargs, + ) FineTunes.SUB_PATH = original_ft_sub_path if response.status_code != HTTPStatus.OK: - logger.error('Query phrase failed, ' + str(response)) + logger.error("Query phrase failed, %s", response) - return response + return response # type: ignore[return-value] @classmethod - def list_phrases(cls, - page: int = 1, - page_size: int = 10, - workspace: str = None, - **kwargs) -> DashScopeAPIResponse: + def list_phrases( + cls, + page: int = 1, + page_size: int = 10, + workspace: str = None, + **kwargs, + ) -> DashScopeAPIResponse: """List all information of phrases. Args: @@ -158,23 +173,27 @@ def list_phrases(cls, DashScopeAPIResponse: The results of listing hot words. """ original_ft_sub_path = FineTunes.SUB_PATH - FineTunes.SUB_PATH = 'fine-tunes/outputs' - response = FineTunes.list(page=page, - page_size=page_size, - workspace=workspace, - **kwargs) + FineTunes.SUB_PATH = "fine-tunes/outputs" + response = FineTunes.list( + page=page, + page_size=page_size, + workspace=workspace, + **kwargs, + ) FineTunes.SUB_PATH = original_ft_sub_path if response.status_code != HTTPStatus.OK: - logger.error('List phrase failed, ' + str(response)) + logger.error("List phrase failed, %s", response) - return response + return response # type: ignore[return-value] @classmethod - def delete_phrases(cls, - phrase_id: str, - workspace: str = None, - **kwargs) -> DashScopeAPIResponse: + def delete_phrases( + cls, + phrase_id: str, + workspace: str = None, + **kwargs, + ) -> DashScopeAPIResponse: """Delete the hot words by phrase_id. 
Args: @@ -188,16 +207,18 @@ def delete_phrases(cls, DashScopeAPIResponse: The results of deleting hot words. """ if phrase_id is None or len(phrase_id) == 0: - raise InvalidParameter('phrase_id is empty!') + raise InvalidParameter("phrase_id is empty!") original_ft_sub_path = FineTunes.SUB_PATH - FineTunes.SUB_PATH = 'fine-tunes/outputs' - response = FineTunes.delete(job_id=phrase_id, - workspace=workspace, - **kwargs) + FineTunes.SUB_PATH = "fine-tunes/outputs" + response = FineTunes.delete( + job_id=phrase_id, + workspace=workspace, + **kwargs, + ) FineTunes.SUB_PATH = original_ft_sub_path if response.status_code != HTTPStatus.OK: - logger.error('Delete phrase failed, ' + str(response)) + logger.error("Delete phrase failed, %s", response) - return response + return response # type: ignore[return-value] diff --git a/dashscope/audio/asr/recognition.py b/dashscope/audio/asr/recognition.py index a0d2f02..4a9c8f8 100644 --- a/dashscope/audio/asr/recognition.py +++ b/dashscope/audio/asr/recognition.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. import json @@ -13,9 +14,13 @@ from dashscope.api_entities.dashscope_response import RecognitionResponse from dashscope.client.base_api import BaseApi from dashscope.common.constants import ApiProtocol -from dashscope.common.error import (InputDataRequired, InputRequired, - InvalidParameter, InvalidTask, - ModelRequired) +from dashscope.common.error import ( + InputDataRequired, + InputRequired, + InvalidParameter, + InvalidTask, + ModelRequired, +) from dashscope.common.logging import logger from dashscope.common.utils import _get_task_group_and_task from dashscope.protocol.websocket import WebsocketStreamingMode @@ -23,73 +28,89 @@ class RecognitionResult(RecognitionResponse): """The result set of speech recognition, including the single-sentence - recognition result returned by the callback mode, and all recognition - results in a synchronized manner. 
+ recognition result returned by the callback mode, and all recognition + results in a synchronized manner. """ - def __init__(self, - response: RecognitionResponse, - sentences: List[Any] = None, - usages: List[Any] = None): + + def __init__( + self, + response: RecognitionResponse, + sentences: List[Any] = None, + usages: List[Any] = None, + ): self.status_code = response.status_code self.request_id = response.request_id self.code = response.code self.message = response.message self.usages = usages if sentences is not None and len(sentences) > 0: - self.output = {'sentence': sentences} + self.output = {"sentence": sentences} else: self.output = response.output - if self.usages is not None and len( - self.usages) > 0 and 'usage' in self.usages[-1]: - self.usage = self.usages[-1]['usage'] + if ( + self.usages is not None + and len( + self.usages, + ) + > 0 + and "usage" in self.usages[-1] + ): + self.usage = self.usages[-1]["usage"] else: self.usage = None def __str__(self): - return json.dumps(RecognitionResponse.from_api_response(self), - ensure_ascii=False) + return json.dumps( + RecognitionResponse.from_api_response(self), + ensure_ascii=False, + ) def get_sentence(self) -> Union[Dict[str, Any], List[Any]]: - """The result of speech recognition. - """ - if self.output and 'sentence' in self.output: - return self.output['sentence'] + """The result of speech recognition.""" + if self.output and "sentence" in self.output: + return self.output["sentence"] else: - return None + return None # type: ignore[return-value] def get_request_id(self) -> str: - """The request_id of speech recognition. - """ + """The request_id of speech recognition.""" return self.request_id def get_usage(self, sentence: Dict[str, Any]) -> Dict[str, Any]: - """Get billing for the input sentence. 
- """ + """Get billing for the input sentence.""" if self.usages is not None: - if sentence is not None and 'end_time' in sentence and sentence[ - 'end_time'] is not None: + if ( + sentence is not None + and "end_time" in sentence + and sentence["end_time"] is not None + ): for usage in self.usages: - if usage['end_time'] == sentence['end_time']: - return usage['usage'] + if usage["end_time"] == sentence["end_time"]: + return usage["usage"] - return None + return None # type: ignore[return-value] @staticmethod def is_sentence_end(sentence: Dict[str, Any]) -> bool: - """Determine whether the speech recognition result is the end of a sentence. - This is a static method. + """Determine whether the speech recognition result is the end of a sentence. # noqa: E501 + This is a static method. """ - if sentence is not None and 'end_time' in sentence and sentence[ - 'end_time'] is not None: + # pylint: disable=simplifiable-if-statement + if ( + sentence is not None + and "end_time" in sentence + and sentence["end_time"] is not None + ): return True else: return False -class RecognitionCallback(): - """An interface that defines callback methods for getting speech recognition results. # noqa E501 - Derive from this class and implement its function to provide your own data. +class RecognitionCallback: + """An interface that defines callback methods for getting speech recognition results. # noqa E501 # pylint: disable=line-too-long + Derive from this class and implement its function to provide your own data. 
""" + def on_open(self) -> None: pass @@ -136,19 +157,21 @@ class Recognition(BaseApi): SILENCE_TIMEOUT_S = 23 - def __init__(self, - model: str, - callback: RecognitionCallback, - format: str, - sample_rate: int, - workspace: str = None, - **kwargs): + def __init__( + self, + model: str, + callback: RecognitionCallback, + format: str, # pylint: disable=redefined-builtin + sample_rate: int, + workspace: str = None, + **kwargs, + ): if model is None: - raise ModelRequired('Model is required!') + raise ModelRequired("Model is required!") if format is None: - raise InputRequired('format is required!') + raise InputRequired("format is required!") if sample_rate is None: - raise InputRequired('sample_rate is required!') + raise InputRequired("sample_rate is required!") self.model = model self.format = format @@ -175,7 +198,9 @@ def __del__(self): self._stream_data = Queue() if self._worker is not None and self._worker.is_alive(): self._worker.join() - if self._silence_timer is not None and self._silence_timer.is_alive( # noqa E501 + if ( + self._silence_timer is not None + and self._silence_timer.is_alive() # noqa E501 ): self._silence_timer.cancel() self._silence_timer = None @@ -184,84 +209,112 @@ def __del__(self): def __receive_worker(self): """Asynchronously, initiate a real-time speech recognition request and - obtain the result for parsing. + obtain the result for parsing. 
""" responses = self.__launch_request() - for part in responses: + for part in responses: # pylint: disable=R1702 if part.status_code == HTTPStatus.OK: - if len(part.output) == 0 or ('finished' in part.output and part.output['finished'] == True): + if len(part.output) == 0 or ( + "finished" in part.output + # pylint: disable=singleton-comparison + and part.output["finished"] == True # noqa: E712 + ): self._on_complete_timestamp = time.time() * 1000 - logger.debug('last package delay {}'.format( - self.get_last_package_delay())) + logger.debug( + "last package delay %s", + self.get_last_package_delay(), + ) self._callback.on_complete() else: usage: Dict[str, Any] = None usages: List[Any] = None - if 'sentence' in part.output: - if 'text' in part.output['sentence'] and part.output['sentence']['text'] != '': - if (self._first_package_timestamp < 0): - self._first_package_timestamp = time.time() * 1000 - logger.debug('first package delay {}'.format( - self.get_first_package_delay())) - sentence = part.output['sentence'] - if 'heartbeat' in sentence and sentence['heartbeat'] == True: - logger.debug('recv heartbeat') + if "sentence" in part.output: + if ( + "text" in part.output["sentence"] + and part.output["sentence"]["text"] != "" + ): + if self._first_package_timestamp < 0: + self._first_package_timestamp = ( + time.time() * 1000 + ) + logger.debug( + "first package delay %s", + self.get_first_package_delay(), + ) + sentence = part.output["sentence"] + if ( + "heartbeat" in sentence + # pylint: disable=singleton-comparison + and sentence["heartbeat"] == True # noqa: E712 + ): + logger.debug("recv heartbeat") continue logger.debug( - 'Recv Result [rid:{}]:{}, isEnd: {}'.format( - part.request_id, sentence, - RecognitionResult.is_sentence_end(sentence))) + "Recv Result [rid:%s]:%s, isEnd: %s", + part.request_id, + sentence, + RecognitionResult.is_sentence_end(sentence), + ) if part.usage is not None: usage = { - 'end_time': - part.output['sentence']['end_time'], - 
'usage': part.usage + "end_time": part.output["sentence"][ + "end_time" + ], + "usage": part.usage, } usages = [usage] - if self.request_id_confirmed is False and part.request_id is not None: + if ( + self.request_id_confirmed is False + and part.request_id is not None + ): self.last_request_id = part.request_id self.request_id_confirmed = True self._callback.on_event( RecognitionResult( RecognitionResponse.from_api_response(part), - usages=usages)) + usages=usages, + ), + ) else: self._running = False self._stream_data = Queue() self._callback.on_error( RecognitionResult( - RecognitionResponse.from_api_response(part))) + RecognitionResponse.from_api_response(part), + ), + ) self._callback.on_close() break def __launch_request(self): - """Initiate real-time speech recognition requests. - """ + """Initiate real-time speech recognition requests.""" resources_list: list = [] if self._phrase is not None and len(self._phrase) > 0: - item = {'resource_id': self._phrase, 'resource_type': 'asr_phrase'} + item = {"resource_id": self._phrase, "resource_type": "asr_phrase"} resources_list.append(item) if len(resources_list) > 0: - self._kwargs['resources'] = resources_list + self._kwargs["resources"] = resources_list self._tidy_kwargs() task_name, _ = _get_task_group_and_task(__name__) - responses = super().call(model=self.model, - task_group='audio', - task=task_name, - function='recognition', - input=self._input_stream_cycle(), - api_protocol=ApiProtocol.WEBSOCKET, - ws_stream_mode=WebsocketStreamingMode.DUPLEX, - is_binary_input=True, - sample_rate=self.sample_rate, - format=self.format, - stream=True, - workspace=self._workspace, - pre_task_id=self.last_request_id, - **self._kwargs) + responses = super().call( + model=self.model, + task_group="audio", + task=task_name, + function="recognition", + input=self._input_stream_cycle(), + api_protocol=ApiProtocol.WEBSOCKET, + ws_stream_mode=WebsocketStreamingMode.DUPLEX, + is_binary_input=True, + sample_rate=self.sample_rate, + 
format=self.format, + stream=True, + workspace=self._workspace, + pre_task_id=self.last_request_id, + **self._kwargs, + ) return responses def start(self, phrase_id: str = None, **kwargs): @@ -288,10 +341,12 @@ def start(self, phrase_id: str = None, **kwargs): if it has already been started. InvalidTask: Task create failed. """ - assert self._callback is not None, 'Please set the callback to get the speech recognition result.' # noqa E501 + assert ( + self._callback is not None + ), "Please set the callback to get the speech recognition result." # noqa E501 if self._running: - raise InvalidParameter('Speech recognition has started.') + raise InvalidParameter("Speech recognition has started.") self._start_stream_timestamp = -1 self._first_package_timestamp = -1 @@ -307,17 +362,22 @@ def start(self, phrase_id: str = None, **kwargs): self._callback.on_open() # If audio data is not received for 23 seconds, the timeout exits - self._silence_timer = Timer(Recognition.SILENCE_TIMEOUT_S, - self._silence_stop_timer) + self._silence_timer = Timer( + Recognition.SILENCE_TIMEOUT_S, + self._silence_stop_timer, + ) self._silence_timer.start() else: self._running = False - raise InvalidTask('Invalid task, task create failed.') - - def call(self, - file: str, - phrase_id: str = None, - **kwargs) -> RecognitionResult: + raise InvalidTask("Invalid task, task create failed.") + + # pylint: disable=R1702,too-many-branches,too-many-statements + def call( # type: ignore[override] # noqa: E501 + self, + file: str, + phrase_id: str = None, + **kwargs, + ) -> RecognitionResult: """Real-time speech recognition in synchronous mode. 
Args: @@ -346,13 +406,13 @@ def call(self, """ self._start_stream_timestamp = time.time() * 1000 if self._running: - raise InvalidParameter('Speech recognition has been called.') + raise InvalidParameter("Speech recognition has been called.") if os.path.exists(file): if os.path.isdir(file): - raise IsADirectoryError('Is a directory: ' + file) + raise IsADirectoryError("Is a directory: " + file) else: - raise FileNotFoundError('No such file or directory: ' + file) + raise FileNotFoundError("No such file or directory: " + file) self._recognition_once = True self._stream_data = Queue() @@ -366,17 +426,20 @@ def call(self, try: audio_data: bytes = None - f = open(file, 'rb') + # pylint: disable=consider-using-with + f = open(file, "rb") if os.path.getsize(file): while True: audio_data = f.read(12800) if not audio_data: break - else: - self._stream_data.put(audio_data) + self._stream_data.put( + audio_data, + ) # pylint: disable=no-else-break else: raise InputDataRequired( - 'The supplied file was empty (zero bytes long)') + "The supplied file was empty (zero bytes long)", + ) f.close() self._stop_stream_timestamp = time.time() * 1000 except Exception as e: @@ -388,26 +451,36 @@ def call(self, responses = self.__launch_request() for part in responses: if part.status_code == HTTPStatus.OK: - if 'sentence' in part.output: - if 'text' in part.output['sentence'] and part.output['sentence']['text'] != '': - if (self._first_package_timestamp < 0): - self._first_package_timestamp = time.time() * 1000 - logger.debug('first package delay {}'.format( - self._first_package_timestamp - - self._start_stream_timestamp)) - sentence = part.output['sentence'] + if "sentence" in part.output: + if ( + "text" in part.output["sentence"] + and part.output["sentence"]["text"] != "" + ): + if self._first_package_timestamp < 0: + self._first_package_timestamp = ( + time.time() * 1000 + ) + logger.debug( + "first package delay %s", + self._first_package_timestamp + - 
self._start_stream_timestamp, + ) + sentence = part.output["sentence"] logger.debug( - 'Recv Result [rid:{}]:{}, isEnd: {}'.format( - part.request_id, sentence, - RecognitionResult.is_sentence_end(sentence))) + "Recv Result [rid:%s]:%s, isEnd: %s", + part.request_id, + sentence, + RecognitionResult.is_sentence_end(sentence), + ) if RecognitionResult.is_sentence_end(sentence): sentences.append(sentence) if part.usage is not None: usage = { - 'end_time': - part.output['sentence']['end_time'], - 'usage': part.usage + "end_time": part.output["sentence"][ + "end_time" + ], + "usage": part.usage, } usages.append(usage) @@ -419,8 +492,10 @@ def call(self, break self._on_complete_timestamp = time.time() * 1000 - logger.debug('last package delay {}'.format( - self.get_last_package_delay())) + logger.debug( + "last package delay %s", + self.get_last_package_delay(), + ) if error_flag: result = RecognitionResult(response) @@ -440,7 +515,7 @@ def stop(self): InvalidParameter: Cannot stop an uninitiated recognition. """ if self._running is False: - raise InvalidParameter('Speech recognition has stopped.') + raise InvalidParameter("Speech recognition has stopped.") self._stop_stream_timestamp = time.time() * 1000 @@ -461,11 +536,11 @@ def send_audio_frame(self, buffer: bytes): InvalidParameter: Cannot send data to an uninitiated recognition. """ if self._running is False: - raise InvalidParameter('Speech recognition has stopped.') + raise InvalidParameter("Speech recognition has stopped.") - if (self._start_stream_timestamp < 0): + if self._start_stream_timestamp < 0: self._start_stream_timestamp = time.time() * 1000 - logger.debug('send_audio_frame: {}'.format(len(buffer))) + logger.debug("send_audio_frame: %s", len(buffer)) self._stream_data.put(buffer) def _tidy_kwargs(self): @@ -479,15 +554,18 @@ def _input_stream_cycle(self): if self._running: time.sleep(0.01) continue - else: - break + break # Reset silence_timer when getting stream. 
- if self._silence_timer is not None and self._silence_timer.is_alive( # noqa E501 + if ( + self._silence_timer is not None + and self._silence_timer.is_alive() # noqa E501 ): self._silence_timer.cancel() - self._silence_timer = Timer(Recognition.SILENCE_TIMEOUT_S, - self._silence_stop_timer) + self._silence_timer = Timer( + Recognition.SILENCE_TIMEOUT_S, + self._silence_stop_timer, + ) self._silence_timer.start() while not self._stream_data.empty(): @@ -504,8 +582,7 @@ def _input_stream_cycle(self): yield bytes(frame) def _silence_stop_timer(self): - """If audio data is not received for a long time, exit worker. - """ + """If audio data is not received for a long time, exit worker.""" self._running = False if self._silence_timer is not None and self._silence_timer.is_alive(): self._silence_timer.cancel() @@ -515,13 +592,11 @@ def _silence_stop_timer(self): self._stream_data = Queue() def get_first_package_delay(self): - """First Package Delay is the time between start sending audio and receive first words package - """ + """First Package Delay is the time between start sending audio and receive first words package""" # noqa: E501 # pylint: disable=line-too-long return self._first_package_timestamp - self._start_stream_timestamp def get_last_package_delay(self): - """Last Package Delay is the time between stop sending audio and receive last words package - """ + """Last Package Delay is the time between stop sending audio and receive last words package""" # noqa: E501 # pylint: disable=line-too-long return self._on_complete_timestamp - self._stop_stream_timestamp # 获取上一个任务的taskId diff --git a/dashscope/audio/asr/transcription.py b/dashscope/audio/asr/transcription.py index 7959e49..aa1e01a 100644 --- a/dashscope/audio/asr/transcription.py +++ b/dashscope/audio/asr/transcription.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
import asyncio @@ -6,8 +7,10 @@ import aiohttp -from dashscope.api_entities.dashscope_response import (DashScopeAPIResponse, - TranscriptionResponse) +from dashscope.api_entities.dashscope_response import ( + DashScopeAPIResponse, + TranscriptionResponse, +) from dashscope.client.base_api import BaseAsyncApi from dashscope.common.constants import ApiProtocol, HTTPMethod from dashscope.common.logging import logger @@ -15,24 +18,25 @@ class Transcription(BaseAsyncApi): - """API for File Transcription models. - """ + """API for File Transcription models.""" MAX_QUERY_TRY_COUNT = 3 class Models: - paraformer_v1 = 'paraformer-v1' - paraformer_8k_v1 = 'paraformer-8k-v1' - paraformer_mtl_v1 = 'paraformer-mtl-v1' + paraformer_v1 = "paraformer-v1" + paraformer_8k_v1 = "paraformer-8k-v1" + paraformer_mtl_v1 = "paraformer-mtl-v1" @classmethod - def call(cls, - model: str, - file_urls: List[str], - phrase_id: str = None, - api_key: str = None, - workspace: str = None, - **kwargs) -> TranscriptionResponse: + def call( # type: ignore[override] + cls, + model: str, + file_urls: List[str], + phrase_id: str = None, + api_key: str = None, + workspace: str = None, + **kwargs, + ) -> TranscriptionResponse: """Transcribe the given files synchronously. 
Args: @@ -60,21 +64,25 @@ def call(cls, """ kwargs.update(cls._fill_resource_id(phrase_id, **kwargs)) kwargs = cls._tidy_kwargs(**kwargs) - response = super().call(model, - file_urls, - api_key=api_key, - workspace=workspace, - **kwargs) + response = super().call( + model, + file_urls, + api_key=api_key, + workspace=workspace, + **kwargs, + ) return TranscriptionResponse.from_api_response(response) @classmethod - def async_call(cls, - model: str, - file_urls: List[str], - phrase_id: str = None, - api_key: str = None, - workspace: str = None, - **kwargs) -> TranscriptionResponse: + def async_call( # type: ignore[override] + cls, + model: str, + file_urls: List[str], + phrase_id: str = None, + api_key: str = None, + workspace: str = None, + **kwargs, + ) -> TranscriptionResponse: """Transcribe the given files asynchronously, return the status of task submission for querying results subsequently. @@ -103,20 +111,24 @@ def async_call(cls, """ kwargs.update(cls._fill_resource_id(phrase_id, **kwargs)) kwargs = cls._tidy_kwargs(**kwargs) - response = cls._launch_request(model, - file_urls, - api_key=api_key, - workspace=workspace, - **kwargs) + response = cls._launch_request( + model, + file_urls, + api_key=api_key, + workspace=workspace, + **kwargs, + ) return TranscriptionResponse.from_api_response(response) @classmethod - def fetch(cls, - task: Union[str, TranscriptionResponse], - api_key: str = None, - workspace: str = None, - **kwargs) -> TranscriptionResponse: - """Fetch the status of task, including results of batch transcription when task_status is SUCCEEDED. # noqa: E501 + def fetch( + cls, + task: Union[str, TranscriptionResponse], # type: ignore[override] + api_key: str = None, + workspace: str = None, + **kwargs, + ) -> TranscriptionResponse: + """Fetch the status of task, including results of batch transcription when task_status is SUCCEEDED. 
# noqa: E501 # pylint: disable=line-too-long Args: task (Union[str, TranscriptionResponse]): The task_id or @@ -130,10 +142,12 @@ def fetch(cls, try_count: int = 0 while True: try: - response = super().fetch(task, - api_key=api_key, - workspace=workspace, - **kwargs) + response = super().fetch( + task, + api_key=api_key, + workspace=workspace, + **kwargs, + ) except (asyncio.TimeoutError, aiohttp.ClientConnectorError) as e: logger.error(e) try_count += 1 @@ -147,11 +161,13 @@ def fetch(cls, return TranscriptionResponse.from_api_response(response) @classmethod - def wait(cls, - task: Union[str, TranscriptionResponse], - api_key: str = None, - workspace: str = None, - **kwargs) -> TranscriptionResponse: + def wait( + cls, + task: Union[str, TranscriptionResponse], # type: ignore[override] + api_key: str = None, + workspace: str = None, + **kwargs, + ) -> TranscriptionResponse: """Poll task until the final results of transcription is obtained. Args: @@ -162,19 +178,23 @@ def wait(cls, Returns: TranscriptionResponse: The result of batch transcription. """ - response = super().wait(task, - api_key=api_key, - workspace=workspace, - **kwargs) + response = super().wait( + task, + api_key=api_key, + workspace=workspace, + **kwargs, + ) return TranscriptionResponse.from_api_response(response) @classmethod - def _launch_request(cls, - model: str, - files: List[str], - api_key: str = None, - workspace: str = None, - **kwargs) -> DashScopeAPIResponse: + def _launch_request( + cls, + model: str, + files: List[str], + api_key: str = None, + workspace: str = None, + **kwargs, + ) -> DashScopeAPIResponse: """Submit transcribe request. 
Args: @@ -190,16 +210,18 @@ def _launch_request(cls, try_count: int = 0 while True: try: - response = super().async_call(model=model, - task_group='audio', - task=task_name, - function=function, - input={'file_urls': files}, - api_protocol=ApiProtocol.HTTP, - http_method=HTTPMethod.POST, - api_key=api_key, - workspace=workspace, - **kwargs) + response = super().async_call( + model=model, + task_group="audio", + task=task_name, + function=function, + input={"file_urls": files}, + api_protocol=ApiProtocol.HTTP, + http_method=HTTPMethod.POST, + api_key=api_key, + workspace=workspace, + **kwargs, + ) except (asyncio.TimeoutError, aiohttp.ClientConnectorError) as e: logger.error(e) try_count += 1 @@ -215,11 +237,11 @@ def _launch_request(cls, def _fill_resource_id(cls, phrase_id: str, **kwargs): resources_list: list = [] if phrase_id is not None and len(phrase_id) > 0: - item = {'resource_id': phrase_id, 'resource_type': 'asr_phrase'} + item = {"resource_id": phrase_id, "resource_type": "asr_phrase"} resources_list.append(item) if len(resources_list) > 0: - kwargs['resources'] = resources_list + kwargs["resources"] = resources_list return kwargs diff --git a/dashscope/audio/asr/translation_recognizer.py b/dashscope/audio/asr/translation_recognizer.py index 3574d56..83dcaa5 100644 --- a/dashscope/audio/asr/translation_recognizer.py +++ b/dashscope/audio/asr/translation_recognizer.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
import json @@ -12,15 +13,19 @@ from dashscope.client.base_api import BaseApi from dashscope.common.constants import ApiProtocol -from dashscope.common.error import (InputDataRequired, InputRequired, - InvalidParameter, InvalidTask, - ModelRequired) +from dashscope.common.error import ( + InputDataRequired, + InputRequired, + InvalidParameter, + InvalidTask, + ModelRequired, +) from dashscope.common.logging import logger from dashscope.common.utils import _get_task_group_and_task from dashscope.protocol.websocket import WebsocketStreamingMode -DASHSCOPE_TRANSLATION_KEY = 'translations' -DASHSCOPE_TRANSCRIPTION_KEY = 'transcription' +DASHSCOPE_TRANSLATION_KEY = "translations" +DASHSCOPE_TRANSCRIPTION_KEY = "transcription" class ThreadSafeBool: @@ -37,8 +42,8 @@ def get(self): return self._value -class WordObj(): - def __init__(self, ) -> None: +class WordObj: + def __init__(self) -> None: self.text: str = None self.begin_time: int = None self.end_time: int = None @@ -47,25 +52,24 @@ def __init__(self, ) -> None: @staticmethod def from_json(json_data: Dict[str, Any]): - """Create a Word object from a JSON dictionary. 
- """ + """Create a Word object from a JSON dictionary.""" word = WordObj() - word.text = json_data['text'] - word.begin_time = json_data['begin_time'] - word.end_time = json_data['end_time'] - word.fixed = json_data['fixed'] - word._raw_data = json_data + word.text = json_data["text"] + word.begin_time = json_data["begin_time"] + word.end_time = json_data["end_time"] + word.fixed = json_data["fixed"] + word._raw_data = json_data # pylint: disable=protected-access return word def __str__(self) -> str: - return 'Word: ' + json.dumps(self._raw_data, ensure_ascii=False) + return "Word: " + json.dumps(self._raw_data, ensure_ascii=False) def __repr__(self): return self.__str__() -class SentenceBaseObj(): - def __init__(self, ) -> None: +class SentenceBaseObj: + def __init__(self) -> None: self.sentence_id: int = -1 self.text: str = None self.begin_time: int = None @@ -75,20 +79,19 @@ def __init__(self, ) -> None: @staticmethod def from_json(json_data: Dict[str, Any]): - """Create a SentenceBase object from a JSON dictionary. 
- """ + """Create a SentenceBase object from a JSON dictionary.""" sentence = SentenceBaseObj() - sentence.sentence_id = json_data['sentence_id'] - sentence.text = json_data['text'] - sentence.begin_time = json_data['begin_time'] - if json_data.get('end_time') is not None: - sentence.end_time = json_data['end_time'] + sentence.sentence_id = json_data["sentence_id"] + sentence.text = json_data["text"] + sentence.begin_time = json_data["begin_time"] + if json_data.get("end_time") is not None: + sentence.end_time = json_data["end_time"] else: - sentence.end_time = json_data['current_time'] + sentence.end_time = json_data["current_time"] sentence.words = [ - WordObj.from_json(word) for word in json_data['words'] + WordObj.from_json(word) for word in json_data["words"] ] - sentence._raw_data = json_data + sentence._raw_data = json_data # pylint: disable=protected-access return sentence def __str__(self) -> str: @@ -99,7 +102,7 @@ def __repr__(self): class TranscriptionResult(SentenceBaseObj): - def __init__(self, ) -> None: + def __init__(self) -> None: self.stash: SentenceBaseObj = None self.is_sentence_end = False # vad related @@ -112,44 +115,46 @@ def __init__(self, ) -> None: @staticmethod def from_json(json_data: Dict[str, Any]): - """Create a TranscriptionResult object from a JSON dictionary. 
- """ + """Create a TranscriptionResult object from a JSON dictionary.""" transcription = TranscriptionResult() - transcription.sentence_id = json_data['sentence_id'] - transcription.text = json_data['text'] - transcription.begin_time = json_data['begin_time'] - if json_data.get('end_time') is not None: - transcription.end_time = json_data['end_time'] + transcription.sentence_id = json_data["sentence_id"] + transcription.text = json_data["text"] + transcription.begin_time = json_data["begin_time"] + if json_data.get("end_time") is not None: + transcription.end_time = json_data["end_time"] else: - transcription.end_time = json_data['current_time'] + transcription.end_time = json_data["current_time"] transcription.words = [ - WordObj.from_json(word) for word in json_data['words'] + WordObj.from_json(word) for word in json_data["words"] ] - transcription._raw_data = json_data - transcription.is_sentence_end = json_data.get('sentence_end') - if 'stash' in json_data: - transcription.stash = SentenceBaseObj.from_json(json_data['stash']) - if 'vad_pre_end' in json_data: - transcription.vad_pre_end = json_data['vad_pre_end'] - if 'pre_end_failed' in json_data: - transcription.pre_end_failed = json_data['pre_end_failed'] - if 'pre_end_start_time' in json_data: - transcription.pre_end_start_time = json_data['pre_end_start_time'] - if 'pre_end_end_time' in json_data: - transcription.pre_end_end_time = json_data['pre_end_end_time'] - transcription._raw_data = json_data + # Store raw JSON data for later use + transcription._raw_data = json_data # pylint: disable=protected-access + transcription.is_sentence_end = json_data.get("sentence_end") + if "stash" in json_data: + transcription.stash = SentenceBaseObj.from_json(json_data["stash"]) + if "vad_pre_end" in json_data: + transcription.vad_pre_end = json_data["vad_pre_end"] + if "pre_end_failed" in json_data: + transcription.pre_end_failed = json_data["pre_end_failed"] + if "pre_end_start_time" in json_data: + 
transcription.pre_end_start_time = json_data["pre_end_start_time"] + if "pre_end_end_time" in json_data: + transcription.pre_end_end_time = json_data["pre_end_end_time"] + transcription._raw_data = json_data # pylint: disable=protected-access return transcription def __str__(self) -> str: - return 'Transcriptions: ' + json.dumps(self._raw_data, - ensure_ascii=False) + return "Transcriptions: " + json.dumps( + self._raw_data, + ensure_ascii=False, + ) def __repr__(self): return self.__str__() class Translation(SentenceBaseObj): - def __init__(self, ) -> None: + def __init__(self) -> None: self.language: str = None self.stash: SentenceBaseObj = None self.is_sentence_end = False @@ -163,85 +168,87 @@ def __init__(self, ) -> None: @staticmethod def from_json(json_data: Dict[str, Any]): - """Create a Translation object from a JSON dictionary. - """ + """Create a Translation object from a JSON dictionary.""" translation = Translation() - translation.sentence_id = json_data['sentence_id'] - translation.text = json_data['text'] - translation.begin_time = json_data['begin_time'] - if json_data.get('end_time') is not None: - translation.end_time = json_data['end_time'] + translation.sentence_id = json_data["sentence_id"] + translation.text = json_data["text"] + translation.begin_time = json_data["begin_time"] + if json_data.get("end_time") is not None: + translation.end_time = json_data["end_time"] else: - translation.end_time = json_data['current_time'] + translation.end_time = json_data["current_time"] translation.words = [ - WordObj.from_json(word) for word in json_data['words'] + WordObj.from_json(word) for word in json_data["words"] ] - translation._raw_data = json_data - - translation.language = json_data['lang'] - translation.is_sentence_end = json_data.get('sentence_end') - if 'stash' in json_data: - translation.stash = SentenceBaseObj.from_json(json_data['stash']) - if 'vad_pre_end' in json_data: - translation.vad_pre_end = json_data['vad_pre_end'] - if 
'pre_end_failed' in json_data: - translation.pre_end_failed = json_data['pre_end_failed'] - if 'pre_end_start_time' in json_data: - translation.pre_end_start_time = json_data['pre_end_start_time'] - if 'pre_end_end_time' in json_data: - translation.pre_end_end_time = json_data['pre_end_end_time'] - translation._raw_data = json_data + # Store raw JSON data for later use + translation._raw_data = json_data # pylint: disable=protected-access + + translation.language = json_data["lang"] + translation.is_sentence_end = json_data.get("sentence_end") + if "stash" in json_data: + translation.stash = SentenceBaseObj.from_json(json_data["stash"]) + if "vad_pre_end" in json_data: + translation.vad_pre_end = json_data["vad_pre_end"] + if "pre_end_failed" in json_data: + translation.pre_end_failed = json_data["pre_end_failed"] + if "pre_end_start_time" in json_data: + translation.pre_end_start_time = json_data["pre_end_start_time"] + if "pre_end_end_time" in json_data: + translation.pre_end_end_time = json_data["pre_end_end_time"] + translation._raw_data = json_data # pylint: disable=protected-access return translation def __str__(self) -> str: - return 'Translation: ' + json.dumps(self._raw_data, ensure_ascii=False) + return "Translation: " + json.dumps(self._raw_data, ensure_ascii=False) def __repr__(self): return self.__str__() -class TranslationResult(): - def __init__(self, ) -> None: - self.translations: Dict[str:Translation] = {} +class TranslationResult: + def __init__(self) -> None: + self.translations: Dict[str, Translation] = {} self.is_sentence_end = False self._raw_data = None def get_translation(self, language) -> Translation: if self.translations is None: return None - return self.translations.get(language) + return self.translations.get(language) # type: ignore[return-value] - def get_language_list(self, ) -> List[str]: + def get_language_list(self) -> List[str]: if self.translations is None: return None return list(self.translations.keys()) @staticmethod def 
from_json(json_data: List): - """Create a TranslationResult object from a JSON dictionary. - """ + """Create a TranslationResult object from a JSON dictionary.""" result = TranslationResult() - result._raw_data = json_data + # Store raw JSON data for later use + result._raw_data = json_data # pylint: disable=protected-access for translation_json in json_data: if not isinstance(translation_json, dict): raise InvalidParameter( - f'Invalid translation json data: {translation_json}') - else: - translation = Translation.from_json(translation_json) - result.translations[translation.language] = translation - if translation.is_sentence_end: - result.is_sentence_end = True + f"Invalid translation json data: {translation_json}", + ) + translation = Translation.from_json(translation_json) + result.translations[translation.language] = translation + if translation.is_sentence_end: + result.is_sentence_end = True return result def __str__(self) -> str: - return 'TranslationList: ' + json.dumps(self._raw_data, - ensure_ascii=False) + return "TranslationList: " + json.dumps( + self._raw_data, + ensure_ascii=False, + ) def __repr__(self): return self.__str__() -class TranslationRecognizerResultPack(): +class TranslationRecognizerResultPack: def __init__(self) -> None: self.transcription_result_list: List[TranscriptionResult] = [] self.translation_result_list: List[TranslationResult] = [] @@ -250,10 +257,11 @@ def __init__(self) -> None: self.error_message = None -class TranslationRecognizerCallback(): - """An interface that defines callback methods for getting translation recognizer results. # noqa E501 - Derive from this class and implement its function to provide your own data. +class TranslationRecognizerCallback: + """An interface that defines callback methods for getting translation recognizer results. # noqa E501 # pylint: disable=line-too-long + Derive from this class and implement its function to provide your own data. 
""" + def on_open(self) -> None: pass @@ -266,8 +274,13 @@ def on_error(self, message) -> None: def on_close(self) -> None: pass - def on_event(self, request_id, transcription_result: TranscriptionResult, - translation_result: TranslationResult, usage) -> None: + def on_event( + self, + request_id, + transcription_result: TranscriptionResult, + translation_result: TranslationResult, + usage, + ) -> None: pass @@ -301,22 +314,24 @@ class TranslationRecognizerRealtime(BaseApi): SILENCE_TIMEOUT_S = 23 - def __init__(self, - model: str, - callback: TranslationRecognizerCallback, - format: str, - sample_rate: int, - transcription_enabled: bool = True, - source_language: str = None, - translation_enabled: bool = False, - workspace: str = None, - **kwargs): + def __init__( + self, + model: str, + callback: TranslationRecognizerCallback, + format: str, # pylint: disable=redefined-builtin + sample_rate: int, + transcription_enabled: bool = True, + source_language: str = None, + translation_enabled: bool = False, + workspace: str = None, + **kwargs, + ): if model is None: - raise ModelRequired('Model is required!') + raise ModelRequired("Model is required!") if format is None: - raise InputRequired('format is required!') + raise InputRequired("format is required!") if sample_rate is None: - raise InputRequired('sample_rate is required!') + raise InputRequired("sample_rate is required!") self.model = model self.format = format @@ -346,7 +361,9 @@ def __del__(self): self._stream_data = Queue() if self._worker is not None and self._worker.is_alive(): self._worker.join() - if self._silence_timer is not None and self._silence_timer.is_alive( # noqa E501 + if ( + self._silence_timer is not None + and self._silence_timer.is_alive() # noqa E501 ): self._silence_timer.cancel() self._silence_timer = None @@ -354,18 +371,23 @@ def __del__(self): self._callback.on_close() def __receive_worker(self): - """Asynchronously, initiate a real-time transltion recognizer request and - obtain the 
result for parsing. + """Asynchronously, initiate a real-time translation recognizer request and # noqa: E501 + obtain the result for parsing. """ responses = self.__launch_request() for part in responses: if part.status_code == HTTPStatus.OK: - logger.debug('Received response request_id: {} {}'.format( - part.request_id, part.output)) + logger.debug( + "Received response request_id: %s %s", + part.request_id, + part.output, + ) if len(part.output) == 0: self._on_complete_timestamp = time.time() * 1000 - logger.debug('last package delay {}'.format( - self.get_last_package_delay())) + logger.debug( + "last package delay %s", + self.get_last_package_delay(), + ) self._callback.on_complete() else: usage = None @@ -373,23 +395,34 @@ def __receive_worker(self): translations = None if DASHSCOPE_TRANSCRIPTION_KEY in part.output: transcription = TranscriptionResult.from_json( - part.output[DASHSCOPE_TRANSCRIPTION_KEY]) + part.output[DASHSCOPE_TRANSCRIPTION_KEY], + ) if DASHSCOPE_TRANSLATION_KEY in part.output: translations = TranslationResult.from_json( - part.output[DASHSCOPE_TRANSLATION_KEY]) + part.output[DASHSCOPE_TRANSLATION_KEY], + ) if transcription is not None or translations is not None: - if (self._first_package_timestamp < 0): + if self._first_package_timestamp < 0: self._first_package_timestamp = time.time() * 1000 - logger.debug('first package delay {}'.format( - self.get_first_package_delay())) + logger.debug( + "first package delay %s", + self.get_first_package_delay(), + ) if part.usage is not None: usage = part.usage - if self.request_id_confirmed is False and part.request_id is not None: + if ( + self.request_id_confirmed is False + and part.request_id is not None + ): self.last_request_id = part.request_id self.request_id_confirmed = True - self._callback.on_event(part.request_id, transcription, - translations, usage) + self._callback.on_event( + part.request_id, + transcription, + translations, + usage, + ) else: self._running = False self._stream_data = 
Queue() @@ -398,16 +431,15 @@ def __receive_worker(self): break def __launch_request(self): - """Initiate real-time translation recognizer requests. - """ + """Initiate real-time translation recognizer requests.""" self._tidy_kwargs() task_name, _ = _get_task_group_and_task(__name__) responses = super().call( model=self.model, - task_group='audio', + task_group="audio", task=task_name, - function='recognition', + function="recognition", input=self._input_stream_cycle(), api_protocol=ApiProtocol.WEBSOCKET, ws_stream_mode=WebsocketStreamingMode.DUPLEX, @@ -420,12 +452,13 @@ def __launch_request(self): translation_enabled=self.translation_enabled, workspace=self._workspace, pre_task_id=self.last_request_id, - **self._kwargs) + **self._kwargs, + ) return responses def start(self, **kwargs): """Real-time translation recognizer in asynchronous mode. - Please call 'stop()' after you have completed translation & recognition. + Please call 'stop()' after you have completed translation & recognition. # noqa: E501 Args: phrase_id (str, `optional`): The ID of phrase. @@ -447,11 +480,14 @@ def start(self, **kwargs): if it has already been started. InvalidTask: Task create failed. """ - assert self._callback is not None, 'Please set the callback to get the translation & recognition result.' # noqa E501 + assert ( + self._callback is not None + ), "Please set the callback to get the translation & recognition result." 
# noqa E501 if self._running: raise InvalidParameter( - 'TranslationRecognizerRealtime has started.') + "TranslationRecognizerRealtime has started.", + ) self._start_stream_timestamp = -1 self._first_package_timestamp = -1 @@ -468,16 +504,20 @@ def start(self, **kwargs): # If audio data is not received for 23 seconds, the timeout exits self._silence_timer = Timer( TranslationRecognizerRealtime.SILENCE_TIMEOUT_S, - self._silence_stop_timer) + self._silence_stop_timer, + ) self._silence_timer.start() else: self._running = False - raise InvalidTask('Invalid task, task create failed.') - - def call(self, - file: str, - phrase_id: str = None, - **kwargs) -> TranslationRecognizerResultPack: + raise InvalidTask("Invalid task, task create failed.") + + # pylint: disable=too-many-branches,too-many-statements + def call( # type: ignore[override] + self, + file: str, + phrase_id: str = None, + **kwargs, + ) -> TranslationRecognizerResultPack: """TranslationRecognizerRealtime in synchronous mode. Args: @@ -502,18 +542,19 @@ def call(self, InputDataRequired: The supplied file was empty. Returns: - TranslationRecognizerResultPack: The result of speech translation & recognition. + TranslationRecognizerResultPack: The result of speech translation & recognition. 
# noqa: E501 # pylint: disable=line-too-long """ self._start_stream_timestamp = time.time() * 1000 if self._running: raise InvalidParameter( - 'TranslationRecognizerRealtime has been called.') + "TranslationRecognizerRealtime has been called.", + ) if os.path.exists(file): if os.path.isdir(file): - raise IsADirectoryError('Is a directory: ' + file) + raise IsADirectoryError("Is a directory: " + file) else: - raise FileNotFoundError('No such file or directory: ' + file) + raise FileNotFoundError("No such file or directory: " + file) self._recognition_once = True self._stream_data = Queue() @@ -524,17 +565,20 @@ def call(self, try: audio_data: bytes = None - f = open(file, 'rb') + # pylint: disable=consider-using-with + f = open(file, "rb") if os.path.getsize(file): while True: audio_data = f.read(12800) if not audio_data: break - else: - self._stream_data.put(audio_data) + self._stream_data.put( + audio_data, + ) # pylint: disable=no-else-break else: raise InputDataRequired( - 'The supplied file was empty (zero bytes long)') + "The supplied file was empty (zero bytes long)", + ) f.close() self._stop_stream_timestamp = time.time() * 1000 except Exception as e: @@ -546,36 +590,47 @@ def call(self, responses = self.__launch_request() for part in responses: if part.status_code == HTTPStatus.OK: - logger.debug('received data: {}'.format(part.output)) + logger.debug("received data: %s", part.output) # debug log cal fpd transcription = None translation = None usage = None - if ('translation' in part.output) or ('transcription' - in part.output): - if (self._first_package_timestamp < 0): + if ("translation" in part.output) or ( + "transcription" in part.output + ): + if self._first_package_timestamp < 0: self._first_package_timestamp = time.time() * 1000 - logger.debug('first package delay {}'.format( - self._first_package_timestamp - - self._start_stream_timestamp)) + logger.debug( + "first package delay %s", + self._first_package_timestamp + - 
self._start_stream_timestamp, + ) if part.usage is not None: usage = part.usage if DASHSCOPE_TRANSCRIPTION_KEY in part.output: transcription = TranscriptionResult.from_json( - part.output[DASHSCOPE_TRANSCRIPTION_KEY]) + part.output[DASHSCOPE_TRANSCRIPTION_KEY], + ) if DASHSCOPE_TRANSLATION_KEY in part.output: translation = TranslationResult.from_json( - part.output[DASHSCOPE_TRANSLATION_KEY]) + part.output[DASHSCOPE_TRANSLATION_KEY], + ) - if (transcription is not None - and transcription.is_sentence_end) or ( - translation is not None - and translation.is_sentence_end): + if ( + transcription is not None + and transcription.is_sentence_end + ) or ( + translation is not None and translation.is_sentence_end + ): results.request_id = part.request_id - results.transcription_result_list.append(transcription) - results.translation_result_list.append(translation) + results.transcription_result_list.append( + transcription, # type: ignore[arg-type] + ) # noqa: E501 + results.translation_result_list.append( + translation, # type: ignore[arg-type] + ) # noqa: E501 results.usage_list.append(usage) else: error_message = part @@ -583,8 +638,10 @@ def call(self, break self._on_complete_timestamp = time.time() * 1000 - logger.debug('last package delay {}'.format( - self.get_last_package_delay())) + logger.debug( + "last package delay %s", + self.get_last_package_delay(), + ) self._stream_data = Queue() self._recognition_once = False @@ -596,11 +653,12 @@ def stop(self): """End asynchronous TranslationRecognizerRealtime. Raises: - InvalidParameter: Cannot stop an uninitiated TranslationRecognizerRealtime. + InvalidParameter: Cannot stop an uninitiated TranslationRecognizerRealtime. 
# noqa: E501 # pylint: disable=line-too-long """ if self._running is False: raise InvalidParameter( - 'TranslationRecognizerRealtime has stopped.') + "TranslationRecognizerRealtime has stopped.", + ) self._stop_stream_timestamp = time.time() * 1000 @@ -618,15 +676,16 @@ def send_audio_frame(self, buffer: bytes): """Push audio to TranslationRecognizerRealtime. Raises: - InvalidParameter: Cannot send data to an uninitiated TranslationRecognizerRealtime. + InvalidParameter: Cannot send data to an uninitiated TranslationRecognizerRealtime. # noqa: E501 # pylint: disable=line-too-long """ if self._running is False: raise InvalidParameter( - 'TranslationRecognizerRealtime has stopped.') + "TranslationRecognizerRealtime has stopped.", + ) - if (self._start_stream_timestamp < 0): + if self._start_stream_timestamp < 0: self._start_stream_timestamp = time.time() * 1000 - logger.debug('send_audio_frame: {}'.format(len(buffer))) + logger.debug("send_audio_frame: %s", len(buffer)) self._stream_data.put(buffer) def _tidy_kwargs(self): @@ -640,16 +699,18 @@ def _input_stream_cycle(self): if self._running: time.sleep(0.01) continue - else: - break + break # Reset silence_timer when getting stream. - if self._silence_timer is not None and self._silence_timer.is_alive( # noqa E501 + if ( + self._silence_timer is not None + and self._silence_timer.is_alive() # noqa E501 ): self._silence_timer.cancel() self._silence_timer = Timer( TranslationRecognizerRealtime.SILENCE_TIMEOUT_S, - self._silence_stop_timer) + self._silence_stop_timer, + ) self._silence_timer.start() while not self._stream_data.empty(): @@ -666,8 +727,7 @@ def _input_stream_cycle(self): yield bytes(frame) def _silence_stop_timer(self): - """If audio data is not received for a long time, exit worker. 
- """ + """If audio data is not received for a long time, exit worker.""" self._running = False if self._silence_timer is not None and self._silence_timer.is_alive(): self._silence_timer.cancel() @@ -677,13 +737,11 @@ def _silence_stop_timer(self): self._stream_data = Queue() def get_first_package_delay(self): - """First Package Delay is the time between start sending audio and receive first words package - """ + """First Package Delay is the time between start sending audio and receive first words package""" # noqa: E501 # pylint: disable=line-too-long return self._first_package_timestamp - self._start_stream_timestamp def get_last_package_delay(self): - """Last Package Delay is the time between stop sending audio and receive last words package - """ + """Last Package Delay is the time between stop sending audio and receive last words package""" # noqa: E501 # pylint: disable=line-too-long return self._on_complete_timestamp - self._stop_stream_timestamp # 获取上一个任务的taskId @@ -721,22 +779,24 @@ class TranslationRecognizerChat(BaseApi): SILENCE_TIMEOUT_S = 23 - def __init__(self, - model: str, - callback: TranslationRecognizerCallback, - format: str, - sample_rate: int, - transcription_enabled: bool = True, - source_language: str = None, - translation_enabled: bool = False, - workspace: str = None, - **kwargs): + def __init__( + self, + model: str, + callback: TranslationRecognizerCallback, + format: str, # pylint: disable=redefined-builtin + sample_rate: int, + transcription_enabled: bool = True, + source_language: str = None, + translation_enabled: bool = False, + workspace: str = None, + **kwargs, + ): if model is None: - raise ModelRequired('Model is required!') + raise ModelRequired("Model is required!") if format is None: - raise InputRequired('format is required!') + raise InputRequired("format is required!") if sample_rate is None: - raise InputRequired('sample_rate is required!') + raise InputRequired("sample_rate is required!") self.model = model self.format 
= format @@ -767,26 +827,33 @@ def __del__(self): self._stream_data = Queue() if self._worker is not None and self._worker.is_alive(): self._worker.join() - if self._silence_timer is not None and self._silence_timer.is_alive( # noqa E501 + if ( + self._silence_timer is not None + and self._silence_timer.is_alive() # noqa E501 ): self._silence_timer.cancel() self._silence_timer = None if self._callback: self._callback.on_close() - def __receive_worker(self): - """Asynchronously, initiate a real-time transltion recognizer request and - obtain the result for parsing. + def __receive_worker(self): # pylint: disable=too-many-branches + """Asynchronously, initiate a real-time transltion recognizer request and # noqa: E501 + obtain the result for parsing. """ responses = self.__launch_request() for part in responses: if part.status_code == HTTPStatus.OK: - logger.debug('Received response request_id: {} {}'.format( - part.request_id, part.output)) + logger.debug( + "Received response request_id: %s %s", + part.request_id, + part.output, + ) if len(part.output) == 0: self._on_complete_timestamp = time.time() * 1000 - logger.debug('last package delay {}'.format( - self.get_last_package_delay())) + logger.debug( + "last package delay %s", + self.get_last_package_delay(), + ) self._callback.on_complete() else: usage = None @@ -794,33 +861,50 @@ def __receive_worker(self): translations = None if DASHSCOPE_TRANSCRIPTION_KEY in part.output: transcription = TranscriptionResult.from_json( - part.output[DASHSCOPE_TRANSCRIPTION_KEY]) + part.output[DASHSCOPE_TRANSCRIPTION_KEY], + ) if DASHSCOPE_TRANSLATION_KEY in part.output: translations = TranslationResult.from_json( - part.output[DASHSCOPE_TRANSLATION_KEY]) + part.output[DASHSCOPE_TRANSLATION_KEY], + ) if transcription is not None or translations is not None: - if (self._first_package_timestamp < 0): + if self._first_package_timestamp < 0: self._first_package_timestamp = time.time() * 1000 - logger.debug('first package delay 
{}'.format( - self.get_first_package_delay())) + logger.debug( + "first package delay %s", + self.get_first_package_delay(), + ) if part.usage is not None: usage = part.usage - if self.request_id_confirmed is False and part.request_id is not None: + if ( + self.request_id_confirmed is False + and part.request_id is not None + ): self.last_request_id = part.request_id self.request_id_confirmed = True - if transcription is not None and transcription.is_sentence_end: + if ( + transcription is not None + and transcription.is_sentence_end + ): logger.debug( - '[Chat] recv sentence end in transcription, stop asr' + "[Chat] recv sentence end in transcription, stop asr", # noqa: E501 ) self._is_sentence_end.set(True) - if translations is not None and translations.is_sentence_end: + if ( + translations is not None + and translations.is_sentence_end + ): logger.debug( - '[Chat] recv sentence end in translation, stop asr' + "[Chat] recv sentence end in translation, stop asr", # noqa: E501 ) self._is_sentence_end.set(True) - self._callback.on_event(part.request_id, transcription, - translations, usage) + self._callback.on_event( + part.request_id, + transcription, + translations, + usage, + ) else: self._running = False self._stream_data = Queue() @@ -829,16 +913,15 @@ def __receive_worker(self): break def __launch_request(self): - """Initiate real-time translation recognizer requests. 
- """ + """Initiate real-time translation recognizer requests.""" self._tidy_kwargs() task_name, _ = _get_task_group_and_task(__name__) responses = super().call( model=self.model, - task_group='audio', + task_group="audio", task=task_name, - function='recognition', + function="recognition", input=self._input_stream_cycle(), api_protocol=ApiProtocol.WEBSOCKET, ws_stream_mode=WebsocketStreamingMode.DUPLEX, @@ -851,12 +934,13 @@ def __launch_request(self): translation_enabled=self.translation_enabled, workspace=self._workspace, pre_task_id=self.last_request_id, - **self._kwargs) + **self._kwargs, + ) return responses def start(self, **kwargs): """Real-time translation recognizer in asynchronous mode. - Please call 'stop()' after you have completed translation & recognition. + Please call 'stop()' after you have completed translation & recognition. # noqa: E501 Args: phrase_id (str, `optional`): The ID of phrase. @@ -878,10 +962,12 @@ def start(self, **kwargs): if it has already been started. InvalidTask: Task create failed. """ - assert self._callback is not None, 'Please set the callback to get the translation & recognition result.' # noqa E501 + assert ( + self._callback is not None + ), "Please set the callback to get the translation & recognition result." # noqa E501 if self._running: - raise InvalidParameter('TranslationRecognizerChat has started.') + raise InvalidParameter("TranslationRecognizerChat has started.") self._start_stream_timestamp = -1 self._first_package_timestamp = -1 @@ -898,23 +984,24 @@ def start(self, **kwargs): # If audio data is not received for 23 seconds, the timeout exits self._silence_timer = Timer( TranslationRecognizerChat.SILENCE_TIMEOUT_S, - self._silence_stop_timer) + self._silence_stop_timer, + ) self._silence_timer.start() else: self._running = False - raise InvalidTask('Invalid task, task create failed.') + raise InvalidTask("Invalid task, task create failed.") def stop(self): """End asynchronous TranslationRecognizerChat. 
Raises: - InvalidParameter: Cannot stop an uninitiated TranslationRecognizerChat. + InvalidParameter: Cannot stop an uninitiated TranslationRecognizerChat. # noqa: E501 """ if self._running is False: - raise InvalidParameter('TranslationRecognizerChat has stopped.') + raise InvalidParameter("TranslationRecognizerChat has stopped.") self._stop_stream_timestamp = time.time() * 1000 - logger.debug('stop TranslationRecognizerChat') + logger.debug("stop TranslationRecognizerChat") self._running = False if self._worker is not None and self._worker.is_alive(): self._worker.join() @@ -929,18 +1016,18 @@ def send_audio_frame(self, buffer: bytes) -> bool: """Push audio to TranslationRecognizerChat. Raises: - InvalidParameter: Cannot send data to an uninitiated TranslationRecognizerChat. + InvalidParameter: Cannot send data to an uninitiated TranslationRecognizerChat. # noqa: E501 # pylint: disable=line-too-long """ if self._is_sentence_end.get(): - logger.debug('skip audio due to has sentence end.') + logger.debug("skip audio due to has sentence end.") return False if self._running is False: - raise InvalidParameter('TranslationRecognizerChat has stopped.') + raise InvalidParameter("TranslationRecognizerChat has stopped.") - if (self._start_stream_timestamp < 0): + if self._start_stream_timestamp < 0: self._start_stream_timestamp = time.time() * 1000 - logger.debug('send_audio_frame: {}'.format(len(buffer))) + logger.debug("send_audio_frame: %s", len(buffer)) self._stream_data.put(buffer) return True @@ -955,16 +1042,18 @@ def _input_stream_cycle(self): if self._running: time.sleep(0.01) continue - else: - break + break # Reset silence_timer when getting stream. 
- if self._silence_timer is not None and self._silence_timer.is_alive( # noqa E501 + if ( + self._silence_timer is not None + and self._silence_timer.is_alive() # noqa E501 ): self._silence_timer.cancel() self._silence_timer = Timer( TranslationRecognizerChat.SILENCE_TIMEOUT_S, - self._silence_stop_timer) + self._silence_stop_timer, + ) self._silence_timer.start() while not self._stream_data.empty(): @@ -981,8 +1070,7 @@ def _input_stream_cycle(self): yield bytes(frame) def _silence_stop_timer(self): - """If audio data is not received for a long time, exit worker. - """ + """If audio data is not received for a long time, exit worker.""" self._running = False if self._silence_timer is not None and self._silence_timer.is_alive(): self._silence_timer.cancel() @@ -992,13 +1080,11 @@ def _silence_stop_timer(self): self._stream_data = Queue() def get_first_package_delay(self): - """First Package Delay is the time between start sending audio and receive first words package - """ + """First Package Delay is the time between start sending audio and receive first words package""" # noqa: E501 # pylint: disable=line-too-long return self._first_package_timestamp - self._start_stream_timestamp def get_last_package_delay(self): - """Last Package Delay is the time between stop sending audio and receive last words package - """ + """Last Package Delay is the time between stop sending audio and receive last words package""" # noqa: E501 # pylint: disable=line-too-long return self._on_complete_timestamp - self._stop_stream_timestamp # 获取上一个任务的taskId diff --git a/dashscope/audio/asr/vocabulary.py b/dashscope/audio/asr/vocabulary.py index 8597ebd..1e96e28 100644 --- a/dashscope/audio/asr/vocabulary.py +++ b/dashscope/audio/asr/vocabulary.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
import asyncio @@ -12,28 +13,36 @@ class VocabularyServiceException(Exception): - def __init__(self, request_id: str, status_code: int, code: str, - error_message: str) -> None: + def __init__( + self, + request_id: str, + status_code: int, + code: str, + error_message: str, + ) -> None: self._request_id = request_id self._status_code = status_code self._code = code self._error_message = error_message def __str__(self): - return f'Request: {self._request_id}, Status Code: {self._status_code}, Code: {self._code}, Error Message: {self._error_message}' + return f"Request: {self._request_id}, Status Code: {self._status_code}, Code: {self._code}, Error Message: {self._error_message}" # noqa: E501 # pylint: disable=line-too-long class VocabularyService(BaseApi): - ''' + """ API for asr vocabulary service - ''' + """ + MAX_QUERY_TRY_COUNT = 3 - def __init__(self, - api_key=None, - workspace=None, - model=None, - **kwargs) -> None: + def __init__( + self, + api_key=None, + workspace=None, + model=None, + **kwargs, + ) -> None: super().__init__() self._api_key = api_key self._workspace = workspace @@ -41,22 +50,24 @@ def __init__(self, self._last_request_id = None self.model = model if self.model is None: - self.model = 'speech-biasing' + self.model = "speech-biasing" - def __call_with_input(self, input): + def __call_with_input(self, input): # pylint: disable=redefined-builtin try_count = 0 while True: try: - response = super().call(model=self.model, - task_group='audio', - task='asr', - function='customization', - input=input, - api_protocol=ApiProtocol.HTTP, - http_method=HTTPMethod.POST, - api_key=self._api_key, - workspace=self._workspace, - **self._kwargs) + response = super().call( + model=self.model, + task_group="audio", + task="asr", + function="customization", + input=input, + api_protocol=ApiProtocol.HTTP, + http_method=HTTPMethod.POST, + api_key=self._api_key, + workspace=self._workspace, + **self._kwargs, + ) except (asyncio.TimeoutError, 
aiohttp.ClientConnectorError) as e: logger.error(e) try_count += 1 @@ -65,113 +76,160 @@ def __call_with_input(self, input): continue break - logger.debug('>>>>recv', response) + logger.debug(">>>>recv %s", response) return response - def create_vocabulary(self, target_model: str, prefix: str, - vocabulary: List[dict]) -> str: - ''' + def create_vocabulary( + self, + target_model: str, + prefix: str, + vocabulary: List[dict], + ) -> str: + """ 创建热词表 param: target_model 热词表对应的语音识别模型版本 param: prefix 热词表自定义前缀,仅允许数字和小写字母,小于十个字符。 param: vocabulary 热词表字典 return: 热词表标识符 vocabulary_id - ''' - response = self.__call_with_input(input={ - 'action': 'create_vocabulary', - 'target_model': target_model, - 'prefix': prefix, - 'vocabulary': vocabulary, - }, ) + """ + # pylint: disable=no-value-for-parameter + response = self.__call_with_input( + input={ + "action": "create_vocabulary", + "target_model": target_model, + "prefix": prefix, + "vocabulary": vocabulary, + }, + ) if response.status_code == 200: self._last_request_id = response.request_id - return response.output['vocabulary_id'] + return response.output["vocabulary_id"] else: - raise VocabularyServiceException(response.request_id, response.status_code, - response.code, response.message) - - def list_vocabularies(self, - prefix=None, - page_index: int = 0, - page_size: int = 10) -> List[dict]: - ''' + raise VocabularyServiceException( + response.request_id, + response.status_code, + response.code, + response.message, + ) + + def list_vocabularies( + self, + prefix=None, + page_index: int = 0, + page_size: int = 10, + ) -> List[dict]: + """ 查询已创建的所有热词表 param: prefix 自定义前缀,如果设定则只返回指定前缀的热词表标识符列表。 param: page_index 查询的页索引 param: page_size 查询页大小 return: 热词表标识符列表 - ''' + """ if prefix: - response = self.__call_with_input(input={ - 'action': 'list_vocabulary', - 'prefix': prefix, - 'page_index': page_index, - 'page_size': page_size, - }, ) + # pylint: disable=no-value-for-parameter + response = self.__call_with_input( + input={ 
+ "action": "list_vocabulary", + "prefix": prefix, + "page_index": page_index, + "page_size": page_size, + }, + ) else: - response = self.__call_with_input(input={ - 'action': 'list_vocabulary', - 'page_index': page_index, - 'page_size': page_size, - }, ) + # pylint: disable=no-value-for-parameter + response = self.__call_with_input( + input={ + "action": "list_vocabulary", + "page_index": page_index, + "page_size": page_size, + }, + ) if response.status_code == 200: self._last_request_id = response.request_id - return response.output['vocabulary_list'] + return response.output["vocabulary_list"] else: - raise VocabularyServiceException(response.request_id, response.status_code, - response.code, response.message) + raise VocabularyServiceException( + response.request_id, + response.status_code, + response.code, + response.message, + ) def query_vocabulary(self, vocabulary_id: str) -> List[dict]: - ''' + """ 获取热词表内容 param: vocabulary_id 热词表标识符 return: 热词表 - ''' - response = self.__call_with_input(input={ - 'action': 'query_vocabulary', - 'vocabulary_id': vocabulary_id, - }, ) + """ + # pylint: disable=no-value-for-parameter + response = self.__call_with_input( + input={ + "action": "query_vocabulary", + "vocabulary_id": vocabulary_id, + }, + ) if response.status_code == 200: self._last_request_id = response.request_id return response.output else: - raise VocabularyServiceException(response.request_id, response.status_code, - response.code, response.message) - - def update_vocabulary(self, vocabulary_id: str, - vocabulary: List[dict]) -> None: - ''' + raise VocabularyServiceException( + response.request_id, + response.status_code, + response.code, + response.message, + ) + + def update_vocabulary( + self, + vocabulary_id: str, + vocabulary: List[dict], + ) -> None: + """ 用新的热词表替换已有热词表 param: vocabulary_id 需要替换的热词表标识符 param: vocabulary 热词表 - ''' - response = self.__call_with_input(input={ - 'action': 'update_vocabulary', - 'vocabulary_id': vocabulary_id, - 
'vocabulary': vocabulary, - }, ) + """ + # pylint: disable=no-value-for-parameter + response = self.__call_with_input( + input={ + "action": "update_vocabulary", + "vocabulary_id": vocabulary_id, + "vocabulary": vocabulary, + }, + ) if response.status_code == 200: self._last_request_id = response.request_id return else: - raise VocabularyServiceException(response.request_id, response.status_code, - response.code, response.message) + raise VocabularyServiceException( + response.request_id, + response.status_code, + response.code, + response.message, + ) def delete_vocabulary(self, vocabulary_id: str) -> None: - ''' + """ 删除热词表 param: vocabulary_id 需要删除的热词表标识符 - ''' - response = self.__call_with_input(input={ - 'action': 'delete_vocabulary', - 'vocabulary_id': vocabulary_id, - }, ) + """ + # pylint: disable=no-value-for-parameter + response = self.__call_with_input( + input={ + "action": "delete_vocabulary", + "vocabulary_id": vocabulary_id, + }, + ) if response.status_code == 200: self._last_request_id = response.request_id return else: - raise VocabularyServiceException(response.request_id, response.status_code, - response.code, response.message) + raise VocabularyServiceException( + response.request_id, + response.status_code, + response.code, + response.message, + ) def get_last_request_id(self): return self._last_request_id diff --git a/dashscope/audio/qwen_asr/__init__.py b/dashscope/audio/qwen_asr/__init__.py index 4cd207a..06a2707 100644 --- a/dashscope/audio/qwen_asr/__init__.py +++ b/dashscope/audio/qwen_asr/__init__.py @@ -1,7 +1,8 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
-from .qwen_transcription import (QwenTranscription) +from .qwen_transcription import QwenTranscription __all__ = [ - 'QwenTranscription', + "QwenTranscription", ] diff --git a/dashscope/audio/qwen_asr/qwen_transcription.py b/dashscope/audio/qwen_asr/qwen_transcription.py index 013237c..9682ae2 100644 --- a/dashscope/audio/qwen_asr/qwen_transcription.py +++ b/dashscope/audio/qwen_asr/qwen_transcription.py @@ -1,31 +1,35 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. import asyncio import time -from typing import List, Union +from typing import Union import aiohttp -from dashscope.api_entities.dashscope_response import (DashScopeAPIResponse, - TranscriptionResponse) +from dashscope.api_entities.dashscope_response import ( + DashScopeAPIResponse, + TranscriptionResponse, +) from dashscope.client.base_api import BaseAsyncApi from dashscope.common.constants import ApiProtocol, HTTPMethod from dashscope.common.logging import logger class QwenTranscription(BaseAsyncApi): - """API for File Transcription models. - """ + """API for File Transcription models.""" MAX_QUERY_TRY_COUNT = 3 @classmethod - def call(cls, - model: str, - file_url: str, - api_key: str = None, - workspace: str = None, - **kwargs) -> TranscriptionResponse: + def call( # type: ignore[override] + cls, + model: str, + file_url: str, + api_key: str = None, + workspace: str = None, + **kwargs, + ) -> TranscriptionResponse: """Transcribe the given files synchronously. Args: @@ -37,20 +41,24 @@ def call(cls, TranscriptionResponse: The result of batch transcription. 
""" kwargs = cls._tidy_kwargs(**kwargs) - response = super().call(model, - file_url, - api_key=api_key, - workspace=workspace, - **kwargs) + response = super().call( + model, + file_url, + api_key=api_key, + workspace=workspace, + **kwargs, + ) return TranscriptionResponse.from_api_response(response) @classmethod - def async_call(cls, - model: str, - file_url: str, - api_key: str = None, - workspace: str = None, - **kwargs) -> TranscriptionResponse: + def async_call( # type: ignore[override] + cls, + model: str, + file_url: str, + api_key: str = None, + workspace: str = None, + **kwargs, + ) -> TranscriptionResponse: """Transcribe the given files asynchronously, return the status of task submission for querying results subsequently. @@ -63,20 +71,24 @@ def async_call(cls, TranscriptionResponse: The response including task_id. """ kwargs = cls._tidy_kwargs(**kwargs) - response = cls._launch_request(model, - file_url, - api_key=api_key, - workspace=workspace, - **kwargs) + response = cls._launch_request( + model, + file_url, + api_key=api_key, + workspace=workspace, + **kwargs, + ) return TranscriptionResponse.from_api_response(response) @classmethod - def fetch(cls, - task: Union[str, TranscriptionResponse], - api_key: str = None, - workspace: str = None, - **kwargs) -> TranscriptionResponse: - """Fetch the status of task, including results of batch transcription when task_status is SUCCEEDED. # noqa: E501 + def fetch( + cls, + task: Union[str, TranscriptionResponse], # type: ignore[override] + api_key: str = None, + workspace: str = None, + **kwargs, + ) -> TranscriptionResponse: + """Fetch the status of task, including results of batch transcription when task_status is SUCCEEDED. 
# noqa: E501 # pylint: disable=line-too-long Args: task (Union[str, TranscriptionResponse]): The task_id or @@ -90,10 +102,12 @@ def fetch(cls, try_count: int = 0 while True: try: - response = super().fetch(task, - api_key=api_key, - workspace=workspace, - **kwargs) + response = super().fetch( + task, + api_key=api_key, + workspace=workspace, + **kwargs, + ) except (asyncio.TimeoutError, aiohttp.ClientConnectorError) as e: logger.error(e) try_count += 1 @@ -107,11 +121,13 @@ def fetch(cls, return TranscriptionResponse.from_api_response(response) @classmethod - def wait(cls, - task: Union[str, TranscriptionResponse], - api_key: str = None, - workspace: str = None, - **kwargs) -> TranscriptionResponse: + def wait( + cls, + task: Union[str, TranscriptionResponse], # type: ignore[override] + api_key: str = None, + workspace: str = None, + **kwargs, + ) -> TranscriptionResponse: """Poll task until the final results of transcription is obtained. Args: @@ -122,19 +138,23 @@ def wait(cls, Returns: TranscriptionResponse: The result of batch transcription. """ - response = super().wait(task, - api_key=api_key, - workspace=workspace, - **kwargs) + response = super().wait( + task, + api_key=api_key, + workspace=workspace, + **kwargs, + ) return TranscriptionResponse.from_api_response(response) @classmethod - def _launch_request(cls, - model: str, - file: str, - api_key: str = None, - workspace: str = None, - **kwargs) -> DashScopeAPIResponse: + def _launch_request( + cls, + model: str, + file: str, + api_key: str = None, + workspace: str = None, + **kwargs, + ) -> DashScopeAPIResponse: """Submit transcribe request. 
Args: @@ -149,16 +169,18 @@ def _launch_request(cls, try_count: int = 0 while True: try: - response = super().async_call(model=model, - task_group='audio', - task='asr', - function='transcription', - input={'file_url': file}, - api_protocol=ApiProtocol.HTTP, - http_method=HTTPMethod.POST, - api_key=api_key, - workspace=workspace, - **kwargs) + response = super().async_call( + model=model, + task_group="audio", + task="asr", + function="transcription", + input={"file_url": file}, + api_protocol=ApiProtocol.HTTP, + http_method=HTTPMethod.POST, + api_key=api_key, + workspace=workspace, + **kwargs, + ) except (asyncio.TimeoutError, aiohttp.ClientConnectorError) as e: logger.error(e) try_count += 1 @@ -173,11 +195,11 @@ def _launch_request(cls, def _fill_resource_id(cls, phrase_id: str, **kwargs): resources_list: list = [] if phrase_id is not None and len(phrase_id) > 0: - item = {'resource_id': phrase_id, 'resource_type': 'asr_phrase'} + item = {"resource_id": phrase_id, "resource_type": "asr_phrase"} resources_list.append(item) if len(resources_list) > 0: - kwargs['resources'] = resources_list + kwargs["resources"] = resources_list return kwargs diff --git a/dashscope/audio/qwen_omni/__init__.py b/dashscope/audio/qwen_omni/__init__.py index a7b2d83..629fba1 100644 --- a/dashscope/audio/qwen_omni/__init__.py +++ b/dashscope/audio/qwen_omni/__init__.py @@ -1,11 +1,16 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
-from .omni_realtime import (AudioFormat, MultiModality, OmniRealtimeCallback, - OmniRealtimeConversation) +from .omni_realtime import ( + AudioFormat, + MultiModality, + OmniRealtimeCallback, + OmniRealtimeConversation, +) __all__ = [ - 'OmniRealtimeCallback', - 'AudioFormat', - 'MultiModality', - 'OmniRealtimeConversation', + "OmniRealtimeCallback", + "AudioFormat", + "MultiModality", + "OmniRealtimeConversation", ] diff --git a/dashscope/audio/qwen_omni/omni_realtime.py b/dashscope/audio/qwen_omni/omni_realtime.py index a654201..22ec9cf 100644 --- a/dashscope/audio/qwen_omni/omni_realtime.py +++ b/dashscope/audio/qwen_omni/omni_realtime.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. import json @@ -9,9 +10,10 @@ import uuid from enum import Enum, unique +import websocket # pylint: disable=wrong-import-order + import dashscope -import websocket -from dashscope.common.error import InputRequired, ModelRequired +from dashscope.common.error import ModelRequired from dashscope.common.logging import logger @@ -20,6 +22,7 @@ class OmniRealtimeCallback: An interface that defines callback methods for getting omni-realtime results. # noqa E501 Derive from this class and implement its function to provide your own data. 
""" + def on_open(self) -> None: pass @@ -38,7 +41,7 @@ class TranslationParams: @dataclass class Corpus: - phrases: Dict[str, Any] = field(default=None) + phrases: Dict[str, Any] = field(default=None) # type: ignore[arg-type] language: str = field(default=None) corpus: Corpus = field(default=None) @@ -49,20 +52,28 @@ class TranscriptionParams: """ TranscriptionParams """ + language: str = field(default=None) sample_rate: int = field(default=16000) input_audio_format: str = field(default="pcm") - corpus: Dict[str, Any] = field(default=None) + corpus: Dict[str, Any] = field(default=None) # type: ignore[arg-type] corpus_text: str = field(default=None) @unique class AudioFormat(Enum): # format, sample_rate, channels, bit_rate, name - PCM_16000HZ_MONO_16BIT = ('pcm', 16000, 'mono', '16bit', 'pcm16') - PCM_24000HZ_MONO_16BIT = ('pcm', 24000, 'mono', '16bit', 'pcm16') + PCM_16000HZ_MONO_16BIT = ("pcm", 16000, "mono", "16bit", "pcm16") + PCM_24000HZ_MONO_16BIT = ("pcm", 24000, "mono", "16bit", "pcm16") - def __init__(self, format, sample_rate, channels, bit_rate, format_str): + def __init__( # pylint: disable=redefined-builtin + self, + format, + sample_rate, + channels, + bit_rate, + format_str, + ): self.format = format self.sample_rate = sample_rate self.channels = channels @@ -73,15 +84,16 @@ def __repr__(self): return self.format_str def __str__(self): - return f'{self.format.upper()} with {self.sample_rate}Hz sample rate, {self.channels} channel, {self.bit_rate} bit rate: {self.format_str}' + return f"{self.format.upper()} with {self.sample_rate}Hz sample rate, {self.channels} channel, {self.bit_rate} bit rate: {self.format_str}" # noqa: E501 # pylint: disable=line-too-long class MultiModality(Enum): """ MultiModality """ - TEXT = 'text' - AUDIO = 'audio' + + TEXT = "text" + AUDIO = "audio" def __str__(self): return self.name @@ -96,7 +108,7 @@ def __init__( workspace=None, url=None, api_key: str = None, - additional_params=None, + additional_params=None, # pylint: 
disable=unused-argument ): """ Qwen Omni Realtime SDK @@ -117,13 +129,13 @@ def __init__( """ if model is None: - raise ModelRequired('Model is required!') + raise ModelRequired("Model is required!") if callback is None: - raise ModelRequired('Callback is required!') + raise ModelRequired("Callback is required!") if url is None: - url = f'wss://dashscope.aliyuncs.com/api-ws/v1/realtime?model={model}' + url = f"wss://dashscope.aliyuncs.com/api-ws/v1/realtime?model={model}" # noqa: E501 else: - url = f'{url}?model={model}' + url = f"{url}?model={model}" self.url = url self.apikey = api_key or dashscope.api_key self.user_headers = headers @@ -143,35 +155,34 @@ def __init__( self.disconnect_event = None def _generate_event_id(self): - ''' + """ generate random event id: event_xxxx - ''' - return 'event_' + uuid.uuid4().hex - - def _get_websocket_header(self, ): - ua = 'dashscope/%s; python/%s; platform/%s; processor/%s' % ( - '1.18.0', # dashscope version - platform.python_version(), - platform.platform(), - platform.processor(), + """ + return "event_" + uuid.uuid4().hex + + def _get_websocket_header(self): + ua = ( + f"dashscope/1.18.0; python/{platform.python_version()}; " + f"platform/{platform.platform()}; " + f"processor/{platform.processor()}" ) headers = { - 'user-agent': ua, - 'Authorization': 'bearer ' + self.apikey, + "user-agent": ua, + "Authorization": "bearer " + self.apikey, } if self.user_headers: headers = {**self.user_headers, **headers} if self.user_workspace: headers = { **headers, - 'X-DashScope-WorkSpace': self.user_workspace, + "X-DashScope-WorkSpace": self.user_workspace, } return headers def connect(self) -> None: - ''' - connect to server, create session and return default session configuration - ''' + """ + connect to server, create session and return default session configuration # noqa: E501 + """ self.ws = websocket.WebSocketApp( self.url, header=self._get_websocket_header(), @@ -184,40 +195,42 @@ def connect(self) -> None: 
self.thread.start() timeout = 5 # 最长等待时间(秒) start_time = time.time() - while (not (self.ws.sock and self.ws.sock.connected) - and (time.time() - start_time) < timeout): + while ( + not (self.ws.sock and self.ws.sock.connected) + and (time.time() - start_time) < timeout + ): time.sleep(0.1) # 短暂休眠,避免密集轮询 if not (self.ws.sock and self.ws.sock.connected): raise TimeoutError( - 'websocket connection could not established within 5s. ' - 'Please check your network connection, firewall settings, or server status.' + "websocket connection could not established within 5s. " + "Please check your network connection, firewall settings, or server status.", # noqa: E501 # pylint: disable=line-too-long ) self.callback.on_open() def __send_str(self, data: str, enable_log: bool = True): if enable_log: - logger.debug('[omni realtime] send string: {}'.format(data)) + logger.debug("[omni realtime] send string: %s", data) self.ws.send(data) - def update_session(self, - output_modalities: List[MultiModality], - voice: str = None, - input_audio_format: AudioFormat = AudioFormat. - PCM_16000HZ_MONO_16BIT, - output_audio_format: AudioFormat = AudioFormat. 
- PCM_24000HZ_MONO_16BIT, - enable_input_audio_transcription: bool = True, - input_audio_transcription_model: str = None, - enable_turn_detection: bool = True, - turn_detection_type: str = 'server_vad', - prefix_padding_ms: int = 300, - turn_detection_threshold: float = 0.2, - turn_detection_silence_duration_ms: int = 800, - turn_detection_param: dict = None, - translation_params: TranslationParams = None, - transcription_params: TranscriptionParams = None, - **kwargs) -> None: - ''' + def update_session( + self, + output_modalities: List[MultiModality], + voice: str = None, + input_audio_format: AudioFormat = AudioFormat.PCM_16000HZ_MONO_16BIT, + output_audio_format: AudioFormat = AudioFormat.PCM_24000HZ_MONO_16BIT, + enable_input_audio_transcription: bool = True, + input_audio_transcription_model: str = None, + enable_turn_detection: bool = True, + turn_detection_type: str = "server_vad", + prefix_padding_ms: int = 300, + turn_detection_threshold: float = 0.2, + turn_detection_silence_duration_ms: int = 800, + turn_detection_param: dict = None, + translation_params: TranslationParams = None, + transcription_params: TranscriptionParams = None, + **kwargs, + ) -> None: + """ update session configuration, should be used before create response Parameters @@ -234,69 +247,81 @@ def update_session(self, enable turn detection turn_detection_threshold: float turn detection threshold, range [-1, 1] - In a noisy environment, it may be necessary to increase the threshold to reduce false detections - In a quiet environment, it may be necessary to decrease the threshold to improve sensitivity + In a noisy environment, it may be necessary to increase the threshold to reduce false detections # noqa: E501 # pylint: disable=line-too-long + In a quiet environment, it may be necessary to decrease the threshold to improve sensitivity # noqa: E501 # pylint: disable=line-too-long turn_detection_silence_duration_ms: int - duration of silence in milliseconds to detect turn, range [200, 
6000] + duration of silence in milliseconds to detect turn, range [200, 6000] # noqa: E501 translation_params: TranslationParams - translation params, include language. Only effective with qwen3-livetranslate-flash-realtime model or + translation params, include language. Only effective with qwen3-livetranslate-flash-realtime model or # noqa: E501 # pylint: disable=line-too-long further models. Do not set this parameter for other models. transcription_params: TranscriptionParams - transcription params, include language, sample_rate, input_audio_format, corpus. + transcription params, include language, sample_rate, input_audio_format, corpus. # noqa: E501 # pylint: disable=line-too-long Only effective with qwen3-asr-flash-realtime model or further models. Do not set this parameter for other models. - ''' + """ self.config = { - 'modalities': [m.value for m in output_modalities], - 'voice': voice, - 'input_audio_format': input_audio_format.format_str, - 'output_audio_format': output_audio_format.format_str, + "modalities": [m.value for m in output_modalities], + "voice": voice, + "input_audio_format": input_audio_format.format_str, + "output_audio_format": output_audio_format.format_str, } if enable_input_audio_transcription: - self.config['input_audio_transcription'] = { - 'model': input_audio_transcription_model, + self.config["input_audio_transcription"] = { + "model": input_audio_transcription_model, } else: - self.config['input_audio_transcription'] = None + self.config["input_audio_transcription"] = None if enable_turn_detection: - self.config['turn_detection'] = { - 'type': turn_detection_type, - 'threshold': turn_detection_threshold, - 'prefix_padding_ms': prefix_padding_ms, - 'silence_duration_ms': turn_detection_silence_duration_ms, + self.config["turn_detection"] = { + "type": turn_detection_type, + "threshold": turn_detection_threshold, + "prefix_padding_ms": prefix_padding_ms, + "silence_duration_ms": turn_detection_silence_duration_ms, } if 
turn_detection_param is not None: - self.config['turn_detection'].update(turn_detection_param) + self.config["turn_detection"].update(turn_detection_param) else: - self.config['turn_detection'] = None + self.config["turn_detection"] = None if translation_params is not None: - self.config['translation'] = { - 'language': translation_params.language, + self.config["translation"] = { + "language": translation_params.language, } if translation_params.corpus is not None: - if translation_params.corpus and translation_params.corpus.phrases is not None: - self.config['translation']['corpus'] = { - 'phrases': translation_params.corpus.phrases + if ( + translation_params.corpus + and translation_params.corpus.phrases is not None + ): + self.config["translation"]["corpus"] = { + "phrases": translation_params.corpus.phrases, } if transcription_params is not None: - self.config['input_audio_transcription'] = {} + self.config["input_audio_transcription"] = {} if transcription_params.language is not None: - self.config['input_audio_transcription'].update({'language': transcription_params.language}) + self.config["input_audio_transcription"].update( + {"language": transcription_params.language}, + ) if transcription_params.corpus_text is not None: transcription_params.corpus = { - "text": transcription_params.corpus_text + "text": transcription_params.corpus_text, } if transcription_params.corpus is not None: - self.config['input_audio_transcription'].update({'corpus': transcription_params.corpus}) - self.config['input_audio_format'] = transcription_params.input_audio_format - self.config['sample_rate'] = transcription_params.sample_rate + self.config["input_audio_transcription"].update( + {"corpus": transcription_params.corpus}, + ) + self.config[ + "input_audio_format" + ] = transcription_params.input_audio_format + self.config["sample_rate"] = transcription_params.sample_rate self.config.update(kwargs) self.__send_str( - json.dumps({ - 'event_id': self._generate_event_id(), - 
'type': 'session.update', - 'session': self.config - })) + json.dumps( + { + "event_id": self._generate_event_id(), + "type": "session.update", + "session": self.config, + }, + ), + ) def end_session(self, timeout: int = 20) -> None: """ @@ -305,7 +330,7 @@ def end_session(self, timeout: int = 20) -> None: Parameters: ----------- timeout: int - Timeout in seconds to wait for the session to end. Default is 20 seconds. + Timeout in seconds to wait for the session to end. Default is 20 seconds. # noqa: E501 """ if self.disconnect_event is not None: # if the event is already set, do nothing @@ -315,10 +340,13 @@ def end_session(self, timeout: int = 20) -> None: self.disconnect_event = threading.Event() self.__send_str( - json.dumps({ - 'event_id': self._generate_event_id(), - 'type': 'session.finish' - })) + json.dumps( + { + "event_id": self._generate_event_id(), + "type": "session.finish", + }, + ), + ) # wait for the event to be set finish_success = self.disconnect_event.wait(timeout) @@ -328,79 +356,100 @@ def end_session(self, timeout: int = 20) -> None: # if the event is not set, close the connection if not finish_success: self.close() - raise TimeoutError("Session end timeout after {} seconds".format(timeout)) + raise TimeoutError( + f"Session end timeout after {timeout} seconds", + ) - def end_session_async(self, ) -> None: + def end_session_async(self) -> None: """ end session asynchronously. 
you need close the connection manually """ # 发送结束会话消息 self.__send_str( - json.dumps({ - 'event_id': self._generate_event_id(), - 'type': 'session.finish' - })) + json.dumps( + { + "event_id": self._generate_event_id(), + "type": "session.finish", + }, + ), + ) def append_audio(self, audio_b64: str) -> None: - ''' + """ send audio in base64 format Parameters ---------- audio_b64: str base64 audio string - ''' - logger.debug('[omni realtime] append audio: {}'.format(len(audio_b64))) + """ + logger.debug("[omni realtime] append audio: %s", len(audio_b64)) self.__send_str( - json.dumps({ - 'event_id': self._generate_event_id(), - 'type': 'input_audio_buffer.append', - 'audio': audio_b64 - }), False) + json.dumps( + { + "event_id": self._generate_event_id(), + "type": "input_audio_buffer.append", + "audio": audio_b64, + }, + ), + False, + ) def append_video(self, video_b64: str) -> None: - ''' + """ send one image frame in video in base64 format Parameters ---------- video_b64: str base64 image string - ''' - logger.debug('[omni realtime] append video: {}'.format(len(video_b64))) + """ + logger.debug("[omni realtime] append video: %s", len(video_b64)) self.__send_str( - json.dumps({ - 'event_id': self._generate_event_id(), - 'type': 'input_image_buffer.append', - 'image': video_b64 - }), False) - - def commit(self, ) -> None: - ''' + json.dumps( + { + "event_id": self._generate_event_id(), + "type": "input_image_buffer.append", + "image": video_b64, + }, + ), + False, + ) + + def commit(self) -> None: + """ Commit the audio and video sent before. When in Server VAD mode, the client does not need to use this method, the server will commit the audio automatically after detecting vad end. 
- ''' + """ self.__send_str( - json.dumps({ - 'event_id': self._generate_event_id(), - 'type': 'input_audio_buffer.commit' - })) + json.dumps( + { + "event_id": self._generate_event_id(), + "type": "input_audio_buffer.commit", + }, + ), + ) - def clear_appended_audio(self, ) -> None: - ''' + def clear_appended_audio(self) -> None: + """ clear the audio sent to server before. - ''' + """ self.__send_str( - json.dumps({ - 'event_id': self._generate_event_id(), - 'type': 'input_audio_buffer.clear' - })) - - def create_response(self, - instructions: str = None, - output_modalities: List[MultiModality] = None) -> None: - ''' + json.dumps( + { + "event_id": self._generate_event_id(), + "type": "input_audio_buffer.clear", + }, + ), + ) + + def create_response( + self, + instructions: str = None, + output_modalities: List[MultiModality] = None, + ) -> None: + """ create response, use audio and video commited before to request llm. When in Server VAD mode, the client does not need to use this method, the server will create response automatically after detecting vad @@ -412,110 +461,139 @@ def create_response(self, instructions to llm output_modalities: list[MultiModality] omni output modalities to be used in session - ''' + """ request = { - 'event_id': self._generate_event_id(), - 'type': 'response.create', - 'response': {} + "event_id": self._generate_event_id(), + "type": "response.create", + "response": {}, } - request['response']['instructions'] = instructions + request["response"]["instructions"] = instructions if output_modalities: - request['response']['modalities'] = [ + request["response"]["modalities"] = [ m.value for m in output_modalities ] self.__send_str(json.dumps(request)) - def cancel_response(self, ) -> None: - ''' + def cancel_response(self) -> None: + """ cancel the current response - ''' + """ self.__send_str( - json.dumps({ - 'event_id': self._generate_event_id(), - 'type': 'response.cancel' - })) + json.dumps( + { + "event_id": 
self._generate_event_id(), + "type": "response.cancel", + }, + ), + ) def send_raw(self, raw_data: str) -> None: - ''' + """ send raw data to server - ''' + """ self.__send_str(raw_data) - def close(self, ) -> None: - ''' + def close(self) -> None: + """ close the connection to server - ''' + """ self.ws.close() # 监听消息的回调函数 - def on_message(self, ws, message): + def on_message( # pylint: disable=unused-argument,too-many-branches + self, + ws, + message, + ): if isinstance(message, str): - logger.debug('[omni realtime] receive string {}'.format( - message[:1024])) + logger.debug( + "[omni realtime] receive string %s", + message[:1024], + ) try: # 尝试将消息解析为JSON json_data = json.loads(message) self.last_message = json_data self.callback.on_event(json_data) - if 'type' in message: - if 'session.created' == json_data['type']: - logger.info('[omni realtime] session created') - self.session_id = json_data['session']['id'] - elif 'session.finished' == json_data['type']: + if "type" in message: + if "session.created" == json_data["type"]: + logger.info("[omni realtime] session created") + self.session_id = json_data["session"]["id"] + elif "session.finished" == json_data["type"]: # wait for the event to be set - logger.info('[omni realtime] session finished') + logger.info("[omni realtime] session finished") if self.disconnect_event is not None: self.disconnect_event.set() - if 'response.created' == json_data['type']: - self.last_response_id = json_data['response']['id'] + if "response.created" == json_data["type"]: + self.last_response_id = json_data["response"]["id"] self.last_response_create_time = time.time() * 1000 self.last_first_audio_delay = None self.last_first_text_delay = None - elif 'response.audio_transcript.delta' == json_data[ - 'type']: - if self.last_response_create_time and self.last_first_text_delay is None: - self.last_first_text_delay = time.time( - ) * 1000 - self.last_response_create_time - elif 'response.audio.delta' == json_data['type']: - if 
self.last_response_create_time and self.last_first_audio_delay is None: - self.last_first_audio_delay = time.time( - ) * 1000 - self.last_response_create_time - elif 'response.done' == json_data['type']: + elif ( + "response.audio_transcript.delta" == json_data["type"] + ): + if ( + self.last_response_create_time + and self.last_first_text_delay is None + ): + self.last_first_text_delay = ( + time.time() * 1000 + - self.last_response_create_time + ) + elif "response.audio.delta" == json_data["type"]: + if ( + self.last_response_create_time + and self.last_first_audio_delay is None + ): + self.last_first_audio_delay = ( + time.time() * 1000 + - self.last_response_create_time + ) + elif "response.done" == json_data["type"]: + # pylint: disable=line-too-long logger.info( - '[Metric] response: {}, first text delay: {}, first audio delay: {}' - .format(self.last_response_id, - self.last_first_text_delay, - self.last_first_audio_delay)) + "[Metric] response: %s, first text delay: %s, first audio delay: %s", # noqa: E501 + self.last_response_id, + self.last_first_text_delay, + self.last_first_audio_delay, + ) except json.JSONDecodeError: - logger.error('Failed to parse message as JSON.') - raise Exception('Failed to parse message as JSON.') + logger.error("Failed to parse message as JSON.") + # pylint: disable=broad-exception-raised,raise-missing-from + raise Exception("Failed to parse message as JSON.") elif isinstance(message, (bytes, bytearray)): # 如果失败,认为是二进制消息 logger.error( - 'should not receive binary message in omni realtime api') - logger.debug('[omni realtime] receive binary {} bytes'.format( - len(message))) + "should not receive binary message in omni realtime api", + ) + logger.debug( + "[omni realtime] receive binary %s bytes", + len(message), + ) - def on_close(self, ws, close_status_code, close_msg): + def on_close( # pylint: disable=unused-argument + self, + ws, + close_status_code, + close_msg, + ): self.callback.on_close(close_status_code, close_msg) # 
WebSocket发生错误的回调函数 - def on_error(self, ws, error): - print(f'websocket closed due to {error}') - raise Exception(f'websocket closed due to {error}') + def on_error(self, ws, error): # pylint: disable=unused-argument + print(f"websocket closed due to {error}") + # pylint: disable=broad-exception-raised + raise Exception(f"websocket closed due to {error}") # 获取上一个任务的taskId def get_session_id(self) -> str: - return self.session_id - - def get_last_message(self) -> str: - return self.last_message + return self.session_id # type: ignore[return-value] def get_last_message(self) -> str: - return self.last_message + return self.last_message # type: ignore[return-value] def get_last_response_id(self) -> str: - return self.last_response_id + return self.last_response_id # type: ignore[return-value] def get_last_first_text_delay(self): return self.last_first_text_delay diff --git a/dashscope/audio/qwen_tts/__init__.py b/dashscope/audio/qwen_tts/__init__.py index fcb2094..aae6cbd 100644 --- a/dashscope/audio/qwen_tts/__init__.py +++ b/dashscope/audio/qwen_tts/__init__.py @@ -1,5 +1,6 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. from .speech_synthesizer import SpeechSynthesizer -__all__ = [SpeechSynthesizer] +__all__ = ["SpeechSynthesizer"] diff --git a/dashscope/audio/qwen_tts/speech_synthesizer.py b/dashscope/audio/qwen_tts/speech_synthesizer.py index 35664bc..f15f7c8 100644 --- a/dashscope/audio/qwen_tts/speech_synthesizer.py +++ b/dashscope/audio/qwen_tts/speech_synthesizer.py @@ -1,34 +1,39 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. from typing import Generator, Union -from dashscope.api_entities.dashscope_response import \ - TextToSpeechResponse +from dashscope.api_entities.dashscope_response import TextToSpeechResponse from dashscope.client.base_api import BaseApi from dashscope.common.error import InputRequired, ModelRequired class SpeechSynthesizer(BaseApi): - """Text-to-speech interface. 
- """ + """Text-to-speech interface.""" - task_group = 'aigc' - task = 'multimodal-generation' - function = 'generation' + task_group = "aigc" + task = "multimodal-generation" + function = "generation" class Models: - qwen_tts = 'qwen-tts' + qwen_tts = "qwen-tts" @classmethod - def call( + def call( # type: ignore[override] cls, model: str, text: str, api_key: str = None, workspace: str = None, - **kwargs - ) -> Union[TextToSpeechResponse, Generator[ - TextToSpeechResponse, None, None]]: + **kwargs, + ) -> Union[ + TextToSpeechResponse, + Generator[ + TextToSpeechResponse, + None, + None, + ], + ]: """Call the conversation model service. Args: @@ -36,11 +41,11 @@ def call( text (str): Text content used for speech synthesis. api_key (str, optional): The api api_key, can be None, if None, will retrieve by rule [1]. - [1]: https://help.aliyun.com/zh/dashscope/developer-reference/api-key-settings. # noqa E501 + [1]: https://help.aliyun.com/zh/dashscope/developer-reference/api-key-settings. # noqa E501 # pylint: disable=line-too-long workspace (str): The dashscope workspace id. **kwargs: stream(bool, `optional`): Enable server-sent events - (ref: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events) # noqa E501 + (ref: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events) # noqa E501 # pylint: disable=line-too-long the result will back partially[qwen-turbo,bailian-v1]. voice: str Voice name. @@ -55,23 +60,26 @@ def call( stream is True, return Generator, otherwise TextToSpeechResponse. 
""" if not text: - raise InputRequired('text is required!') + raise InputRequired("text is required!") if model is None or not model: - raise ModelRequired('Model is required!') - input = {'text': text} - if 'voice' in kwargs: - input['voice'] = kwargs.pop('voice') - response = super().call(model=model, - task_group=SpeechSynthesizer.task_group, - task=SpeechSynthesizer.task, - function=SpeechSynthesizer.function, - api_key=api_key, - input=input, - workspace=workspace, - **kwargs) - is_stream = kwargs.get('stream', False) + raise ModelRequired("Model is required!") + input = {"text": text} # pylint: disable=redefined-builtin + if "voice" in kwargs: + input["voice"] = kwargs.pop("voice") + response = super().call( + model=model, + task_group=SpeechSynthesizer.task_group, + task=SpeechSynthesizer.task, + function=SpeechSynthesizer.function, + api_key=api_key, + input=input, + workspace=workspace, + **kwargs, + ) + is_stream = kwargs.get("stream", False) if is_stream: - return (TextToSpeechResponse.from_api_response(rsp) - for rsp in response) + return ( + TextToSpeechResponse.from_api_response(rsp) for rsp in response + ) else: return TextToSpeechResponse.from_api_response(response) diff --git a/dashscope/audio/qwen_tts_realtime/__init__.py b/dashscope/audio/qwen_tts_realtime/__init__.py index 1dd1a19..953753c 100644 --- a/dashscope/audio/qwen_tts_realtime/__init__.py +++ b/dashscope/audio/qwen_tts_realtime/__init__.py @@ -1,10 +1,14 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
-from .qwen_tts_realtime import (AudioFormat, QwenTtsRealtimeCallback, - QwenTtsRealtime) +from .qwen_tts_realtime import ( + AudioFormat, + QwenTtsRealtimeCallback, + QwenTtsRealtime, +) __all__ = [ - 'AudioFormat', - 'QwenTtsRealtimeCallback', - 'QwenTtsRealtime', + "AudioFormat", + "QwenTtsRealtimeCallback", + "QwenTtsRealtime", ] diff --git a/dashscope/audio/qwen_tts_realtime/qwen_tts_realtime.py b/dashscope/audio/qwen_tts_realtime/qwen_tts_realtime.py index 165d40c..0bb9dcd 100644 --- a/dashscope/audio/qwen_tts_realtime/qwen_tts_realtime.py +++ b/dashscope/audio/qwen_tts_realtime/qwen_tts_realtime.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. import json @@ -7,9 +8,10 @@ import uuid from enum import Enum, unique +import websocket # pylint: disable=wrong-import-order + import dashscope -import websocket -from dashscope.common.error import InputRequired, ModelRequired +from dashscope.common.error import ModelRequired from dashscope.common.logging import logger @@ -18,6 +20,7 @@ class QwenTtsRealtimeCallback: An interface that defines callback methods for getting omni-realtime results. # noqa E501 Derive from this class and implement its function to provide your own data. 
""" + def on_open(self) -> None: pass @@ -31,9 +34,16 @@ def on_event(self, message: str) -> None: @unique class AudioFormat(Enum): # format, sample_rate, channels, bit_rate, name - PCM_24000HZ_MONO_16BIT = ('pcm', 24000, 'mono', '16bit', 'pcm16') + PCM_24000HZ_MONO_16BIT = ("pcm", 24000, "mono", "16bit", "pcm16") - def __init__(self, format, sample_rate, channels, bit_rate, format_str): + def __init__( # pylint: disable=redefined-builtin + self, + format, + sample_rate, + channels, + bit_rate, + format_str, + ): self.format = format self.sample_rate = sample_rate self.channels = channels @@ -44,7 +54,7 @@ def __repr__(self): return self.format_str def __str__(self): - return f'{self.format.upper()} with {self.sample_rate}Hz sample rate, {self.channels} channel, {self.bit_rate} bit rate: {self.format_str}' + return f"{self.format.upper()} with {self.sample_rate}Hz sample rate, {self.channels} channel, {self.bit_rate} bit rate: {self.format_str}" # noqa: E501 # pylint: disable=line-too-long class QwenTtsRealtime: @@ -55,7 +65,7 @@ def __init__( callback: QwenTtsRealtimeCallback = None, workspace=None, url=None, - additional_params=None, + additional_params=None, # pylint: disable=unused-argument ): """ Qwen Tts Realtime SDK @@ -76,11 +86,11 @@ def __init__( """ if model is None: - raise ModelRequired('Model is required!') + raise ModelRequired("Model is required!") if url is None: - url = f'wss://dashscope.aliyuncs.com/api-ws/v1/realtime?model={model}' + url = f"wss://dashscope.aliyuncs.com/api-ws/v1/realtime?model={model}" # noqa: E501 else: - url = f'{url}?model={model}' + url = f"{url}?model={model}" self.url = url self.apikey = dashscope.api_key self.user_headers = headers @@ -97,35 +107,34 @@ def __init__( self.metrics = [] def _generate_event_id(self): - ''' + """ generate random event id: event_xxxx - ''' - return 'event_' + uuid.uuid4().hex - - def _get_websocket_header(self, ): - ua = 'dashscope/%s; python/%s; platform/%s; processor/%s' % ( - '1.18.0', # 
dashscope version - platform.python_version(), - platform.platform(), - platform.processor(), + """ + return "event_" + uuid.uuid4().hex + + def _get_websocket_header(self): + ua = ( + f"dashscope/1.18.0; python/{platform.python_version()}; " + f"platform/{platform.platform()}; " + f"processor/{platform.processor()}" ) headers = { - 'user-agent': ua, - 'Authorization': 'bearer ' + self.apikey, + "user-agent": ua, + "Authorization": "bearer " + self.apikey, } if self.user_headers: headers = {**self.user_headers, **headers} if self.user_workspace: headers = { **headers, - 'X-DashScope-WorkSpace': self.user_workspace, + "X-DashScope-WorkSpace": self.user_workspace, } return headers def connect(self) -> None: - ''' - connect to server, create session and return default session configuration - ''' + """ + connect to server, create session and return default session configuration # noqa: E501 + """ self.ws = websocket.WebSocketApp( self.url, header=self._get_websocket_header(), @@ -138,36 +147,39 @@ def connect(self) -> None: self.thread.start() timeout = 5 # 最长等待时间(秒) start_time = time.time() - while (not (self.ws.sock and self.ws.sock.connected) - and (time.time() - start_time) < timeout): + while ( + not (self.ws.sock and self.ws.sock.connected) + and (time.time() - start_time) < timeout + ): time.sleep(0.1) # 短暂休眠,避免密集轮询 if not (self.ws.sock and self.ws.sock.connected): raise TimeoutError( - 'websocket connection could not established within 5s. ' - 'Please check your network connection, firewall settings, or server status.' + "websocket connection could not established within 5s. 
" + "Please check your network connection, firewall settings, or server status.", # noqa: E501 # pylint: disable=line-too-long ) self.callback.on_open() def __send_str(self, data: str, enable_log: bool = True): if enable_log: - logger.debug('[qwen tts realtime] send string: {}'.format(data)) + logger.debug("[qwen tts realtime] send string: %s", data) self.ws.send(data) - def update_session(self, - voice: str, - response_format: AudioFormat = AudioFormat. - PCM_24000HZ_MONO_16BIT, - mode: str = 'server_commit', - sample_rate: int = None, - volume: int = None, - speech_rate: float = None, - audio_format: str = None, - pitch_rate: float = None, - bit_rate: int = None, - language_type: str = None, - enable_tn: bool = None, - **kwargs) -> None: - ''' + def update_session( + self, + voice: str, + response_format: AudioFormat = AudioFormat.PCM_24000HZ_MONO_16BIT, + mode: str = "server_commit", + sample_rate: int = None, + volume: int = None, + speech_rate: float = None, + audio_format: str = None, + pitch_rate: float = None, + bit_rate: int = None, + language_type: str = None, + enable_tn: bool = None, + **kwargs, + ) -> None: + """ update session configuration, should be used before create response Parameters @@ -181,7 +193,7 @@ def update_session(self, language_type: str language type for synthesized audio, default is 'auto' sample_rate: int - sampleRate for tts, range [8000,16000,22050,24000,44100,48000] default is 24000 + sampleRate for tts, range [8000,16000,22050,24000,44100,48000] default is 24000 # noqa: E501 # pylint: disable=line-too-long volume: int volume for tts, range [0,100] default is 50 speech_rate: float @@ -191,155 +203,196 @@ def update_session(self, pitch_rate: float pitch_rate for tts, range [0.5~2.0] default is 1.0 bit_rate: int - bit_rate for tts, support 6~510,default is 128kbps. only work on format: opus/mp3 + bit_rate for tts, support 6~510,default is 128kbps. 
only work on format: opus/mp3 # noqa: E501 # pylint: disable=line-too-long enable_tn: bool enable text normalization for tts, default is None - ''' + """ self.config = { - 'voice': voice, - 'mode': mode, - 'response_format': response_format.format, - 'sample_rate': response_format.sample_rate, + "voice": voice, + "mode": mode, + "response_format": response_format.format, + "sample_rate": response_format.sample_rate, } if sample_rate is not None: # 如果配置,则更新 - self.config['sample_rate'] = sample_rate + self.config["sample_rate"] = sample_rate if volume is not None: - self.config['volume'] = volume + self.config["volume"] = volume if speech_rate is not None: - self.config['speech_rate'] = speech_rate + self.config["speech_rate"] = speech_rate if audio_format is not None: - self.config['response_format'] = audio_format # 如果配置,则更新 + self.config["response_format"] = audio_format # 如果配置,则更新 if pitch_rate is not None: - self.config['pitch_rate'] = pitch_rate + self.config["pitch_rate"] = pitch_rate if bit_rate is not None: - self.config['bit_rate'] = bit_rate + self.config["bit_rate"] = bit_rate if enable_tn is not None: - self.config['enable_tn'] = enable_tn + self.config["enable_tn"] = enable_tn if language_type is not None: - self.config['language_type'] = language_type + self.config["language_type"] = language_type self.config.update(kwargs) self.__send_str( - json.dumps({ - 'event_id': self._generate_event_id(), - 'type': 'session.update', - 'session': self.config - })) + json.dumps( + { + "event_id": self._generate_event_id(), + "type": "session.update", + "session": self.config, + }, + ), + ) def append_text(self, text: str) -> None: - ''' + """ send text Parameters ---------- text: str text to send - ''' + """ self.__send_str( - json.dumps({ - 'event_id': self._generate_event_id(), - 'type': 'input_text_buffer.append', - 'text': text - })) + json.dumps( + { + "event_id": self._generate_event_id(), + "type": "input_text_buffer.append", + "text": text, + }, + ), + ) 
if self.last_first_text_time is None: self.last_first_text_time = time.time() * 1000 - def commit(self, ) -> None: - ''' + def commit(self) -> None: + """ commit the text sent before, create response and start synthesis audio. - ''' + """ self.__send_str( - json.dumps({ - 'event_id': self._generate_event_id(), - 'type': 'input_text_buffer.commit' - })) + json.dumps( + { + "event_id": self._generate_event_id(), + "type": "input_text_buffer.commit", + }, + ), + ) - def clear_appended_text(self, ) -> None: - ''' + def clear_appended_text(self) -> None: + """ clear the text sent to server before. - ''' + """ self.__send_str( - json.dumps({ - 'event_id': self._generate_event_id(), - 'type': 'input_text_buffer.clear' - })) + json.dumps( + { + "event_id": self._generate_event_id(), + "type": "input_text_buffer.clear", + }, + ), + ) - def cancel_response(self, ) -> None: - ''' + def cancel_response(self) -> None: + """ cancel the current response - ''' + """ self.__send_str( - json.dumps({ - 'event_id': self._generate_event_id(), - 'type': 'response.cancel' - })) + json.dumps( + { + "event_id": self._generate_event_id(), + "type": "response.cancel", + }, + ), + ) def send_raw(self, raw_data: str) -> None: - ''' + """ send raw data to server - ''' + """ self.__send_str(raw_data) - def finish(self, ) -> None: - ''' - finish input text stream, server will synthesis all text in buffer and close the connection - ''' + def finish(self) -> None: + """ + finish input text stream, server will synthesis all text in buffer and close the connection # noqa: E501 # pylint: disable=line-too-long + """ self.__send_str( - json.dumps({ - 'event_id': self._generate_event_id(), - 'type': 'session.finish' - })) + json.dumps( + { + "event_id": self._generate_event_id(), + "type": "session.finish", + }, + ), + ) - def close(self, ) -> None: - ''' + def close(self) -> None: + """ close the connection to server - ''' + """ self.ws.close() # 监听消息的回调函数 - def on_message(self, ws, message): + def 
on_message( # pylint: disable=unused-argument + self, + ws, + message, + ): if isinstance(message, str): - logger.debug('[omni realtime] receive string {}'.format( - message[:1024])) + logger.debug( + "[omni realtime] receive string %s", + message[:1024], + ) try: # 尝试将消息解析为JSON json_data = json.loads(message) self.last_message = json_data self.callback.on_event(json_data) - if 'type' in message: - if 'session.created' == json_data['type']: - self.session_id = json_data['session']['id'] - if 'response.created' == json_data['type']: - self.last_response_id = json_data['response']['id'] - elif 'response.audio.delta' == json_data['type']: - if self.last_first_text_time and self.last_first_audio_delay is None: - self.last_first_audio_delay = time.time( - ) * 1000 - self.last_first_text_time - elif 'response.done' == json_data['type']: + if "type" in message: + if "session.created" == json_data["type"]: + self.session_id = json_data["session"]["id"] + if "response.created" == json_data["type"]: + self.last_response_id = json_data["response"]["id"] + elif "response.audio.delta" == json_data["type"]: + if ( + self.last_first_text_time + and self.last_first_audio_delay is None + ): + self.last_first_audio_delay = ( + time.time() * 1000 - self.last_first_text_time + ) + elif "response.done" == json_data["type"]: logger.debug( - '[Metric] response: {}, first audio delay: {}' - .format(self.last_response_id, - self.last_first_audio_delay)) + "[Metric] response: %s, first audio delay: %s", # noqa: E501 + self.last_response_id, + self.last_first_audio_delay, + ) except json.JSONDecodeError: - logger.error('Failed to parse message as JSON.') - raise Exception('Failed to parse message as JSON.') + logger.error("Failed to parse message as JSON.") + # pylint: disable=broad-exception-raised,raise-missing-from + raise Exception("Failed to parse message as JSON.") elif isinstance(message, (bytes, bytearray)): # 如果失败,认为是二进制消息 logger.error( - 'should not receive binary message in omni 
realtime api') - logger.debug('[omni realtime] receive binary {} bytes'.format( - len(message))) + "should not receive binary message in omni realtime api", + ) + logger.debug( + "[omni realtime] receive binary %s bytes", + len(message), + ) - def on_close(self, ws, close_status_code, close_msg): + def on_close( # pylint: disable=unused-argument + self, + ws, + close_status_code, + close_msg, + ): logger.debug( - '[omni realtime] connection closed with code {} and message {}'.format( - close_status_code, close_msg)) + "[omni realtime] connection closed with code %s and message %s", # noqa: E501 + close_status_code, + close_msg, + ) self.callback.on_close(close_status_code, close_msg) # WebSocket发生错误的回调函数 - def on_error(self, ws, error): - print(f'websocket closed due to {error}') - raise Exception(f'websocket closed due to {error}') + def on_error(self, ws, error): # pylint: disable=unused-argument + print(f"websocket closed due to {error}") + # pylint: disable=broad-exception-raised + raise Exception(f"websocket closed due to {error}") # 获取上一个任务的taskId def get_session_id(self): diff --git a/dashscope/audio/tts/__init__.py b/dashscope/audio/tts/__init__.py index 85055fb..c1a364c 100644 --- a/dashscope/audio/tts/__init__.py +++ b/dashscope/audio/tts/__init__.py @@ -1,6 +1,10 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
-from .speech_synthesizer import (ResultCallback, SpeechSynthesisResult, - SpeechSynthesizer) +from .speech_synthesizer import ( + ResultCallback, + SpeechSynthesisResult, + SpeechSynthesizer, +) -__all__ = [SpeechSynthesizer, ResultCallback, SpeechSynthesisResult] +__all__ = ["SpeechSynthesizer", "ResultCallback", "SpeechSynthesisResult"] diff --git a/dashscope/audio/tts/speech_synthesizer.py b/dashscope/audio/tts/speech_synthesizer.py index 80c89cc..0381abe 100644 --- a/dashscope/audio/tts/speech_synthesizer.py +++ b/dashscope/audio/tts/speech_synthesizer.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. from http import HTTPStatus @@ -9,9 +10,9 @@ from dashscope.common.utils import _get_task_group_and_task -class SpeechSynthesisResult(): +class SpeechSynthesisResult: """The result set of speech synthesis, including audio data, - timestamp information, and final result information. + timestamp information, and final result information. """ _audio_frame: bytes = None @@ -20,9 +21,14 @@ class SpeechSynthesisResult(): _sentences: List[Dict[str, str]] = None _response: SpeechSynthesisResponse = None - def __init__(self, frame: bytes, data: bytes, sentence: Dict[str, str], - sentences: List[Dict[str, str]], - response: SpeechSynthesisResponse): + def __init__( + self, + frame: bytes, + data: bytes, + sentence: Dict[str, str], + sentences: List[Dict[str, str]], + response: SpeechSynthesisResponse, + ): if frame is not None: self._audio_frame = bytes(frame) if data is not None: @@ -35,13 +41,11 @@ def __init__(self, frame: bytes, data: bytes, sentence: Dict[str, str], self._response = response def get_audio_frame(self) -> bytes: - """Obtain the audio frame data of speech synthesis through callbacks. - """ + """Obtain the audio frame data of speech synthesis through callbacks.""" # noqa: E501 return self._audio_frame def get_audio_data(self) -> bytes: - """Get complete audio data for speech synthesis. 
- """ + """Get complete audio data for speech synthesis.""" return self._audio_data def get_timestamp(self) -> Dict[str, str]: @@ -51,8 +55,7 @@ def get_timestamp(self) -> Dict[str, str]: return self._sentence def get_timestamps(self) -> List[Dict[str, str]]: - """Get complete timestamp information for all speech synthesis sentences. - """ + """Get complete timestamp information for all speech synthesis sentences.""" # noqa: E501 return self._sentences def get_response(self) -> SpeechSynthesisResponse: @@ -62,11 +65,12 @@ def get_response(self) -> SpeechSynthesisResponse: return self._response -class ResultCallback(): +class ResultCallback: """ An interface that defines callback methods for getting speech synthesis results. # noqa E501 Derive from this class and implement its function to provide your own data. """ + def on_open(self) -> None: pass @@ -84,20 +88,23 @@ def on_event(self, result: SpeechSynthesisResult) -> None: class SpeechSynthesizer(BaseApi): - """Text-to-speech interface. - """ + """Text-to-speech interface.""" + class AudioFormat: - format_wav = 'wav' - format_pcm = 'pcm' - format_mp3 = 'mp3' + format_wav = "wav" + format_pcm = "pcm" + format_mp3 = "mp3" @classmethod - def call(cls, - model: str, - text: str, - callback: ResultCallback = None, - workspace: str = None, - **kwargs) -> SpeechSynthesisResult: + # type: ignore[override] + def call( # pylint: disable=R1702,too-many-branches # type: ignore[override] # noqa: E501 + cls, + model: str, + text: str, + callback: ResultCallback = None, + workspace: str = None, + **kwargs, + ) -> SpeechSynthesisResult: """Convert text to speech synchronously. 
Args: @@ -132,15 +139,17 @@ def call(cls, _task_failed_flag: bool = False task_name, _ = _get_task_group_and_task(__name__) - response = super().call(model=model, - task_group='audio', - task=task_name, - function='SpeechSynthesizer', - input={'text': text}, - stream=True, - api_protocol=ApiProtocol.WEBSOCKET, - workspace=workspace, - **kwargs) + response = super().call( + model=model, + task_group="audio", + task=task_name, + function="SpeechSynthesizer", + input={"text": text}, + stream=True, + api_protocol=ApiProtocol.WEBSOCKET, + workspace=workspace, + **kwargs, + ) if _callback is not None: _callback.on_open() @@ -149,7 +158,12 @@ def call(cls, if isinstance(part.output, bytes): if _callback is not None: audio_frame = SpeechSynthesisResult( - bytes(part.output), None, None, None, None) + bytes(part.output), + None, # type: ignore[arg-type] + None, # type: ignore[arg-type] + None, # type: ignore[arg-type] + None, # type: ignore[arg-type] + ) _callback.on_event(audio_frame) if _audio_data is None: @@ -160,38 +174,56 @@ def call(cls, else: if part.status_code == HTTPStatus.OK: if part.output is None: - _the_final_response = SpeechSynthesisResponse.from_api_response( # noqa E501 - part) + _the_final_response = SpeechSynthesisResponse.from_api_response( # noqa E501 # pylint: disable=line-too-long + part, + ) else: if _callback is not None: sentence = SpeechSynthesisResult( - None, None, part.output['sentence'], None, - None) + None, # type: ignore[arg-type] + None, # type: ignore[arg-type] + part.output["sentence"], + None, # type: ignore[arg-type] + None, # type: ignore[arg-type] + ) _callback.on_event(sentence) if len(_sentences) == 0: - _sentences.append(part.output['sentence']) + _sentences.append(part.output["sentence"]) else: - if _sentences[-1]['begin_time'] == part.output[ - 'sentence']['begin_time']: - if _sentences[-1]['end_time'] != part.output[ - 'sentence']['end_time']: + if ( + _sentences[-1]["begin_time"] + == part.output["sentence"]["begin_time"] + 
): + if ( + _sentences[-1]["end_time"] + != part.output["sentence"]["end_time"] + ): _sentences.pop() - _sentences.append(part.output['sentence']) + _sentences.append(part.output["sentence"]) else: - _sentences.append(part.output['sentence']) + _sentences.append(part.output["sentence"]) else: _task_failed_flag = True - _the_final_response = SpeechSynthesisResponse.from_api_response( # noqa E501 - part) + _the_final_response = ( + SpeechSynthesisResponse.from_api_response( # noqa E501 + part, + ) + ) if _callback is not None: _callback.on_error( - SpeechSynthesisResponse.from_api_response(part)) + SpeechSynthesisResponse.from_api_response(part), + ) if _callback is not None: if _task_failed_flag is False: _callback.on_complete() _callback.on_close() - result = SpeechSynthesisResult(None, _audio_data, None, _sentences, - _the_final_response) + result = SpeechSynthesisResult( + None, # type: ignore[arg-type] + _audio_data, + None, # type: ignore[arg-type] + _sentences, + _the_final_response, # type: ignore[arg-type] + ) return result diff --git a/dashscope/audio/tts_v2/__init__.py b/dashscope/audio/tts_v2/__init__.py index 45dc191..9c0cc5d 100644 --- a/dashscope/audio/tts_v2/__init__.py +++ b/dashscope/audio/tts_v2/__init__.py @@ -1,12 +1,19 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
from .enrollment import VoiceEnrollmentException, VoiceEnrollmentService -from .speech_synthesizer import (AudioFormat, ResultCallback, - SpeechSynthesizer, - SpeechSynthesizerObjectPool) +from .speech_synthesizer import ( + AudioFormat, + ResultCallback, + SpeechSynthesizer, + SpeechSynthesizerObjectPool, +) __all__ = [ - 'SpeechSynthesizer', 'ResultCallback', 'AudioFormat', - 'VoiceEnrollmentException', 'VoiceEnrollmentService', - 'SpeechSynthesizerObjectPool' + "SpeechSynthesizer", + "ResultCallback", + "AudioFormat", + "VoiceEnrollmentException", + "VoiceEnrollmentService", + "SpeechSynthesizerObjectPool", ] diff --git a/dashscope/audio/tts_v2/enrollment.py b/dashscope/audio/tts_v2/enrollment.py index a254c49..35f2c54 100644 --- a/dashscope/audio/tts_v2/enrollment.py +++ b/dashscope/audio/tts_v2/enrollment.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. import asyncio @@ -12,28 +13,36 @@ class VoiceEnrollmentException(Exception): - def __init__(self, request_id: str, status_code: int, code: str, - error_message: str) -> None: + def __init__( + self, + request_id: str, + status_code: int, + code: str, + error_message: str, + ) -> None: self._request_id = request_id self._status_code = status_code self._code = code self._error_message = error_message def __str__(self): - return f'Request: {self._request_id}, Status Code: {self._status_code}, Code: {self._code}, Error Message: {self._error_message}' + return f"Request: {self._request_id}, Status Code: {self._status_code}, Code: {self._code}, Error Message: {self._error_message}" # noqa: E501 # pylint: disable=line-too-long class VoiceEnrollmentService(BaseApi): - ''' + """ API for voice clone service - ''' + """ + MAX_QUERY_TRY_COUNT = 3 - def __init__(self, - api_key=None, - workspace=None, - model=None, - **kwargs) -> None: + def __init__( + self, + api_key=None, + workspace=None, + model=None, + **kwargs, + ) -> None: super().__init__() self._api_key = api_key 
self._workspace = workspace @@ -41,22 +50,27 @@ def __init__(self, self._last_request_id = None self.model = model if self.model is None: - self.model = 'voice-enrollment' + self.model = "voice-enrollment" - def __call_with_input(self, input): + def __call_with_input( # pylint: disable=redefined-builtin + self, + input, + ): try_count = 0 while True: try: - response = super().call(model=self.model, - task_group='audio', - task='tts', - function='customization', - input=input, - api_protocol=ApiProtocol.HTTP, - http_method=HTTPMethod.POST, - api_key=self._api_key, - workspace=self._workspace, - **self._kwargs) + response = super().call( + model=self.model, + task_group="audio", + task="tts", + function="customization", + input=input, + api_protocol=ApiProtocol.HTTP, + http_method=HTTPMethod.POST, + api_key=self._api_key, + workspace=self._workspace, + **self._kwargs, + ) except (asyncio.TimeoutError, aiohttp.ClientConnectorError) as e: logger.error(e) try_count += 1 @@ -65,115 +79,158 @@ def __call_with_input(self, input): continue break - logger.debug('>>>>recv', response) + logger.debug(">>>>recv %s", response) return response - def create_voice(self, target_model: str, prefix: str, url: str, language_hints: List[str] = None) -> str: - ''' + def create_voice( + self, + target_model: str, + prefix: str, + url: str, + language_hints: List[str] = None, + ) -> str: + """ 创建新克隆音色 param: target_model 克隆音色对应的语音合成模型版本 param: prefix 音色自定义前缀,仅允许数字和小写字母,小于十个字符。 param: url 用于克隆的音频文件url param: language_hints 克隆音色目标语言 return: voice_id - ''' + """ input_params = { - 'action': 'create_voice', - 'target_model': target_model, - 'prefix': prefix, - 'url': url + "action": "create_voice", + "target_model": target_model, + "prefix": prefix, + "url": url, } if language_hints is not None: - input_params['language_hints'] = language_hints + input_params["language_hints"] = language_hints response = self.__call_with_input(input_params) self._last_request_id = response.request_id if 
response.status_code == 200: - return response.output['voice_id'] + return response.output["voice_id"] else: - raise VoiceEnrollmentException(response.request_id, response.status_code, response.code, - response.message) - - def list_voices(self, - prefix=None, - page_index: int = 0, - page_size: int = 10) -> List[dict]: - ''' + raise VoiceEnrollmentException( + response.request_id, + response.status_code, + response.code, + response.message, + ) + + def list_voices( + self, + prefix=None, + page_index: int = 0, + page_size: int = 10, + ) -> List[dict]: + """ 查询已创建的所有音色 param: page_index 查询的页索引 param: page_size 查询页大小 return: List[dict] 音色列表,包含每个音色的id,创建时间,修改时间,状态。 - ''' + """ if prefix: - response = self.__call_with_input(input={ - 'action': 'list_voice', - 'prefix': prefix, - 'page_index': page_index, - 'page_size': page_size, - }, ) + # pylint: disable=no-value-for-parameter + response = self.__call_with_input( + input={ + "action": "list_voice", + "prefix": prefix, + "page_index": page_index, + "page_size": page_size, + }, + ) else: - response = self.__call_with_input(input={ - 'action': 'list_voice', - 'page_index': page_index, - 'page_size': page_size, - }, ) + # pylint: disable=no-value-for-parameter + response = self.__call_with_input( + input={ + "action": "list_voice", + "page_index": page_index, + "page_size": page_size, + }, + ) self._last_request_id = response.request_id if response.status_code == 200: - return response.output['voice_list'] + return response.output["voice_list"] else: - raise VoiceEnrollmentException(response.request_id, response.status_code, response.code, - response.message) + raise VoiceEnrollmentException( + response.request_id, + response.status_code, + response.code, + response.message, + ) def query_voice(self, voice_id: str) -> List[str]: - ''' + """ 查询已创建的所有音色 param: voice_id 需要查询的音色 return: bytes 注册音色使用的音频 - ''' - response = self.__call_with_input(input={ - 'action': 'query_voice', - 'voice_id': voice_id, - }, ) + """ + # 
pylint: disable=no-value-for-parameter + response = self.__call_with_input( + input={ + "action": "query_voice", + "voice_id": voice_id, + }, + ) self._last_request_id = response.request_id if response.status_code == 200: return response.output else: - raise VoiceEnrollmentException(response.request_id, response.status_code, response.code, - response.message) + raise VoiceEnrollmentException( + response.request_id, + response.status_code, + response.code, + response.message, + ) def update_voice(self, voice_id: str, url: str) -> None: - ''' + """ 更新音色 param: voice_id 音色id param: url 用于克隆的音频文件url - ''' - response = self.__call_with_input(input={ - 'action': 'update_voice', - 'voice_id': voice_id, - 'url': url, - }, ) + """ + # pylint: disable=no-value-for-parameter + response = self.__call_with_input( + input={ + "action": "update_voice", + "voice_id": voice_id, + "url": url, + }, + ) self._last_request_id = response.request_id if response.status_code == 200: return else: - raise VoiceEnrollmentException(response.request_id, response.status_code, response.code, - response.message) + raise VoiceEnrollmentException( + response.request_id, + response.status_code, + response.code, + response.message, + ) def delete_voice(self, voice_id: str) -> None: - ''' + """ 删除音色 param: voice_id 需要删除的音色 - ''' - response = self.__call_with_input(input={ - 'action': 'delete_voice', - 'voice_id': voice_id, - }, ) + """ + # pylint: disable=no-value-for-parameter + response = self.__call_with_input( + input={ + "action": "delete_voice", + "voice_id": voice_id, + }, + ) self._last_request_id = response.request_id if response.status_code == 200: return else: - raise VoiceEnrollmentException(response.request_id, response.status_code, response.code, - response.message) + raise VoiceEnrollmentException( + response.request_id, + response.status_code, + response.code, + response.message, + ) def get_last_request_id(self): return self._last_request_id diff --git 
a/dashscope/audio/tts_v2/speech_synthesizer.py b/dashscope/audio/tts_v2/speech_synthesizer.py index f280c8a..c980e6b 100644 --- a/dashscope/audio/tts_v2/speech_synthesizer.py +++ b/dashscope/audio/tts_v2/speech_synthesizer.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. import json @@ -13,9 +14,15 @@ import dashscope from dashscope.common.error import InputRequired, InvalidTask, ModelRequired from dashscope.common.logging import logger -from dashscope.protocol.websocket import (ACTION_KEY, EVENT_KEY, HEADER, - TASK_ID, ActionType, EventType, - WebsocketStreamingMode) +from dashscope.protocol.websocket import ( + ACTION_KEY, + EVENT_KEY, + HEADER, + TASK_ID, + ActionType, + EventType, + WebsocketStreamingMode, +) class ResultCallback: @@ -23,6 +30,7 @@ class ResultCallback: An interface that defines callback methods for getting speech synthesis results. # noqa E501 Derive from this class and implement its function to provide your own data. """ + def on_open(self) -> None: pass @@ -44,27 +52,27 @@ def on_data(self, data: bytes) -> None: @unique class AudioFormat(Enum): - DEFAULT = ('Default', 0, '0', 0) - WAV_8000HZ_MONO_16BIT = ('wav', 8000, 'mono', 0) - WAV_16000HZ_MONO_16BIT = ('wav', 16000, 'mono', 16) - WAV_22050HZ_MONO_16BIT = ('wav', 22050, 'mono', 16) - WAV_24000HZ_MONO_16BIT = ('wav', 24000, 'mono', 16) - WAV_44100HZ_MONO_16BIT = ('wav', 44100, 'mono', 16) - WAV_48000HZ_MONO_16BIT = ('wav', 48000, 'mono', 16) - - MP3_8000HZ_MONO_128KBPS = ('mp3', 8000, 'mono', 128) - MP3_16000HZ_MONO_128KBPS = ('mp3', 16000, 'mono', 128) - MP3_22050HZ_MONO_256KBPS = ('mp3', 22050, 'mono', 256) - MP3_24000HZ_MONO_256KBPS = ('mp3', 24000, 'mono', 256) - MP3_44100HZ_MONO_256KBPS = ('mp3', 44100, 'mono', 256) - MP3_48000HZ_MONO_256KBPS = ('mp3', 48000, 'mono', 256) - - PCM_8000HZ_MONO_16BIT = ('pcm', 8000, 'mono', 16) - PCM_16000HZ_MONO_16BIT = ('pcm', 16000, 'mono', 16) - PCM_22050HZ_MONO_16BIT = ('pcm', 22050, 'mono', 16) - 
PCM_24000HZ_MONO_16BIT = ('pcm', 24000, 'mono', 16) - PCM_44100HZ_MONO_16BIT = ('pcm', 44100, 'mono', 16) - PCM_48000HZ_MONO_16BIT = ('pcm', 48000, 'mono', 16) + DEFAULT = ("Default", 0, "0", 0) + WAV_8000HZ_MONO_16BIT = ("wav", 8000, "mono", 0) + WAV_16000HZ_MONO_16BIT = ("wav", 16000, "mono", 16) + WAV_22050HZ_MONO_16BIT = ("wav", 22050, "mono", 16) + WAV_24000HZ_MONO_16BIT = ("wav", 24000, "mono", 16) + WAV_44100HZ_MONO_16BIT = ("wav", 44100, "mono", 16) + WAV_48000HZ_MONO_16BIT = ("wav", 48000, "mono", 16) + + MP3_8000HZ_MONO_128KBPS = ("mp3", 8000, "mono", 128) + MP3_16000HZ_MONO_128KBPS = ("mp3", 16000, "mono", 128) + MP3_22050HZ_MONO_256KBPS = ("mp3", 22050, "mono", 256) + MP3_24000HZ_MONO_256KBPS = ("mp3", 24000, "mono", 256) + MP3_44100HZ_MONO_256KBPS = ("mp3", 44100, "mono", 256) + MP3_48000HZ_MONO_256KBPS = ("mp3", 48000, "mono", 256) + + PCM_8000HZ_MONO_16BIT = ("pcm", 8000, "mono", 16) + PCM_16000HZ_MONO_16BIT = ("pcm", 16000, "mono", 16) + PCM_22050HZ_MONO_16BIT = ("pcm", 22050, "mono", 16) + PCM_24000HZ_MONO_16BIT = ("pcm", 24000, "mono", 16) + PCM_44100HZ_MONO_16BIT = ("pcm", 44100, "mono", 16) + PCM_48000HZ_MONO_16BIT = ("pcm", 48000, "mono", 16) OGG_OPUS_8KHZ_MONO_32KBPS = ("opus", 8000, "mono", 32) OGG_OPUS_8KHZ_MONO_16KBPS = ("opus", 8000, "mono", 16) @@ -77,23 +85,30 @@ class AudioFormat(Enum): OGG_OPUS_48KHZ_MONO_16KBPS = ("opus", 48000, "mono", 16) OGG_OPUS_48KHZ_MONO_32KBPS = ("opus", 48000, "mono", 32) OGG_OPUS_48KHZ_MONO_64KBPS = ("opus", 48000, "mono", 64) - def __init__(self, format, sample_rate, channels, bit_rate): + + def __init__( # pylint: disable=redefined-builtin + self, + format, + sample_rate, + channels, + bit_rate, + ): self.format = format self.sample_rate = sample_rate self.channels = channels self.bit_rate = bit_rate def __str__(self): - return f'{self.format.upper()} with {self.sample_rate}Hz sample rate, {self.channels} channel, {self.bit_rate}' + return f"{self.format.upper()} with {self.sample_rate}Hz sample rate, 
{self.channels} channel, {self.bit_rate}" # noqa: E501 # pylint: disable=line-too-long class Request: - def __init__( + def __init__( # pylint: disable=redefined-builtin self, apikey, model, voice, - format='wav', + format="wav", sample_rate=16000, bit_rate=64000, volume=50, @@ -124,60 +139,60 @@ def genUid(self): return uuid.uuid4().hex def getWebsocketHeaders(self, headers, workspace): - ua = 'dashscope/%s; python/%s; platform/%s; processor/%s' % ( - '1.18.0', # dashscope version - platform.python_version(), - platform.platform(), - platform.processor(), + ua = ( + f"dashscope/1.18.0; python/{platform.python_version()}; " + f"platform/{platform.platform()}; " + f"processor/{platform.processor()}" ) self.headers = { - 'user-agent': ua, - 'Authorization': 'bearer ' + self.apikey, + "user-agent": ua, + "Authorization": "bearer " + self.apikey, } if headers: self.headers = {**self.headers, **headers} if workspace: self.headers = { **self.headers, - 'X-DashScope-WorkSpace': workspace, + "X-DashScope-WorkSpace": workspace, } return self.headers def getStartRequest(self, additional_params=None): - cmd = { HEADER: { ACTION_KEY: ActionType.START, TASK_ID: self.task_id, - 'streaming': WebsocketStreamingMode.DUPLEX, + "streaming": WebsocketStreamingMode.DUPLEX, }, - 'payload': { - 'model': self.model, - 'task_group': 'audio', - 'task': 'tts', - 'function': 'SpeechSynthesizer', - 'input': {}, - 'parameters': { - 'voice': self.voice, - 'volume': self.volume, - 'text_type': 'PlainText', - 'sample_rate': self.sample_rate, - 'rate': self.speech_rate, - 'format': self.format, - 'pitch': self.pitch_rate, - 'seed': self.seed, - 'type': self.synthesis_type + "payload": { + "model": self.model, + "task_group": "audio", + "task": "tts", + "function": "SpeechSynthesizer", + "input": {}, + "parameters": { + "voice": self.voice, + "volume": self.volume, + "text_type": "PlainText", + "sample_rate": self.sample_rate, + "rate": self.speech_rate, + "format": self.format, + "pitch": 
self.pitch_rate, + "seed": self.seed, + "type": self.synthesis_type, }, }, } - if self.format == 'opus': - cmd['payload']['parameters']['bit_rate'] = self.bit_rate + if self.format == "opus": + cmd["payload"]["parameters"]["bit_rate"] = self.bit_rate if additional_params: - cmd['payload']['parameters'].update(additional_params) + cmd["payload"]["parameters"].update(additional_params) if self.instruction is not None: - cmd['payload']['parameters']['instruction'] = self.instruction + cmd["payload"]["parameters"]["instruction"] = self.instruction if self.language_hints is not None: - cmd['payload']['parameters']['language_hints'] = self.language_hints + cmd["payload"]["parameters"][ + "language_hints" + ] = self.language_hints return json.dumps(cmd) def getContinueRequest(self, text): @@ -185,15 +200,15 @@ def getContinueRequest(self, text): HEADER: { ACTION_KEY: ActionType.CONTINUE, TASK_ID: self.task_id, - 'streaming': WebsocketStreamingMode.DUPLEX, + "streaming": WebsocketStreamingMode.DUPLEX, }, - 'payload': { - 'model': self.model, - 'task_group': 'audio', - 'task': 'tts', - 'function': 'SpeechSynthesizer', - 'input': { - 'text': text + "payload": { + "model": self.model, + "task_group": "audio", + "task": "tts", + "function": "SpeechSynthesizer", + "input": { + "text": text, }, }, } @@ -204,17 +219,17 @@ def getFinishRequest(self): HEADER: { ACTION_KEY: ActionType.FINISHED, TASK_ID: self.task_id, - 'streaming': WebsocketStreamingMode.DUPLEX, + "streaming": WebsocketStreamingMode.DUPLEX, }, - 'payload': { - 'input': {}, + "payload": { + "input": {}, }, } return json.dumps(cmd) class SpeechSynthesizer: - def __init__( + def __init__( # pylint: disable=redefined-builtin self, model, voice, @@ -245,9 +260,9 @@ def __init__( volume: int The volume of the synthesized audio, with a range from 0 to 100. Default is 50. rate: float - The speech rate of the synthesized audio, with a range from 0.5 to 2. Default is 1.0. 
+ The speech rate of the synthesized audio, with a range from 0.5 to 2. Default is 1.0. # noqa: E501 # pylint: disable=line-too-long pitch: float - The pitch of the synthesized audio, with a range from 0.5 to 2. Default is 1.0. + The pitch of the synthesized audio, with a range from 0.5 to 2. Default is 1.0. # noqa: E501 # pylint: disable=line-too-long headers: Dict User-defined headers. callback: ResultCallback @@ -284,29 +299,45 @@ def __init__( self._recv_audio_length = 0 self.last_response = None self._close_ws_after_use = True - self.__update_params(model, voice, format, volume, speech_rate, - pitch_rate, seed, synthesis_type, instruction, language_hints, headers, callback, workspace, url, - additional_params) + self.__update_params( + model, + voice, + format, + volume, + speech_rate, + pitch_rate, + seed, + synthesis_type, + instruction, + language_hints, + headers, + callback, + workspace, + url, + additional_params, + ) def __send_str(self, data: str): - logger.debug('>>>send {}'.format(data)) + logger.debug(">>>send %s", data) self.ws.send(data) def __connect(self, timeout_seconds=5) -> None: """ Establish a connection to the Bailian WebSocket server, - which can be used to pre-establish the connection and reduce interaction latency. + which can be used to pre-establish the connection and reduce interaction latency. # noqa: E501 # pylint: disable=line-too-long If this function is not used to create the connection, - it will be established when you first send text via call or streaming_call. + it will be established when you first send text via call or streaming_call. # noqa: E501 Parameters: ----------- timeout: int - Throws TimeoutError exception if the connection is not established after times out seconds. + Throws TimeoutError exception if the connection is not established after times out seconds. 
# noqa: E501 # pylint: disable=line-too-long """ self.ws = websocket.WebSocketApp( self.url, - header=self.request.getWebsocketHeaders(headers=self.headers, - workspace=self.workspace), + header=self.request.getWebsocketHeaders( + headers=self.headers, + workspace=self.workspace, + ), on_message=self.on_message, on_error=self.on_error, on_close=self.on_close, @@ -316,13 +347,15 @@ def __connect(self, timeout_seconds=5) -> None: self.thread.start() # 等待连接建立 start_time = time.time() - while (not (self.ws.sock and self.ws.sock.connected) - and (time.time() - start_time) < timeout_seconds): + while ( + not (self.ws.sock and self.ws.sock.connected) + and (time.time() - start_time) < timeout_seconds + ): time.sleep(0.1) # 短暂休眠,避免密集轮询 if not (self.ws.sock and self.ws.sock.connected): raise TimeoutError( - 'websocket connection could not established within 5s. ' - 'Please check your network connection, firewall settings, or server status.' + "websocket connection could not established within 5s. 
" + "Please check your network connection, firewall settings, or server status.", # noqa: E501 # pylint: disable=line-too-long ) def __is_connected(self) -> bool: @@ -336,7 +369,7 @@ def __is_connected(self) -> bool: return False return True - def __reset(self): + def __reset(self): # pylint: disable=unused-private-member self.start_event.clear() self.complete_event.clear() self._stopped.clear() @@ -352,7 +385,7 @@ def __reset(self): self._recv_audio_length = 0 self.last_response = None - def __update_params( + def __update_params( # pylint: disable=redefined-builtin self, model, voice, @@ -372,25 +405,25 @@ def __update_params( close_ws_after_use=True, ): if model is None: - raise ModelRequired('Model is required!') + raise ModelRequired("Model is required!") if format is None: - raise InputRequired('format is required!') + raise InputRequired("format is required!") if url is None: url = dashscope.base_websocket_api_url self.url = url self.apikey = dashscope.api_key if self.apikey is None: - raise InputRequired('apikey is required!') + raise InputRequired("apikey is required!") self.headers = headers self.workspace = workspace self.additional_params = additional_params self.model = model self.voice = voice self.aformat = format.format - if (self.aformat == 'DEFAULT'): - self.aformat = 'mp3' + if self.aformat == "DEFAULT": + self.aformat = "mp3" self.sample_rate = format.sample_rate - if (self.sample_rate == 0): + if self.sample_rate == 0: self.sample_rate = 22050 self.callback = callback @@ -402,37 +435,41 @@ def __update_params( voice=voice, format=format.format, sample_rate=format.sample_rate, - bit_rate = format.bit_rate, + bit_rate=format.bit_rate, volume=volume, speech_rate=speech_rate, pitch_rate=pitch_rate, seed=seed, synthesis_type=synthesis_type, instruction=instruction, - language_hints=language_hints + language_hints=language_hints, ) self.last_request_id = self.request.task_id self._close_ws_after_use = close_ws_after_use def __str__(self): - return 
'[SpeechSynthesizer {} desc] model:{}, voice:{}, format:{}, sample_rate:{}, connected:{}'.format( - self.__hash__(), self.model, self.voice, self.aformat, - self.sample_rate, self.__is_connected()) + # pylint: disable=line-too-long + return ( + f"[SpeechSynthesizer {self.__hash__()} desc] " + f"model:{self.model}, voice:{self.voice}, " + f"format:{self.aformat}, sample_rate:{self.sample_rate}, " + f"connected:{self.__is_connected()}" + ) - def __start_stream(self, ): + def __start_stream(self): self._start_stream_timestamp = time.time() * 1000 self._first_package_timestamp = -1 self._recv_audio_length = 0 if self.callback is None: - raise InputRequired('callback is required!') + raise InputRequired("callback is required!") # reset inner params self._stopped.clear() - self._stream_data = [''] + self._stream_data = [""] self._worker = None self._audio_data: bytes = None if self._is_started: - raise InvalidTask('task has already started.') + raise InvalidTask("task has already started.") # 建立ws连接 if self.ws is None: self.__connect(5) @@ -440,23 +477,24 @@ def __start_stream(self, ): request = self.request.getStartRequest(self.additional_params) self.__send_str(request) if not self.start_event.wait(10): - raise TimeoutError('start speech synthesizer failed within 5s.') + raise TimeoutError("start speech synthesizer failed within 5s.") self._is_started = True if self.callback: self.callback.on_open() def __submit_text(self, text): if not self._is_started: - raise InvalidTask('speech synthesizer has not been started.') + raise InvalidTask("speech synthesizer has not been started.") if self._stopped.is_set(): - raise InvalidTask('speech synthesizer task has stopped.') + raise InvalidTask("speech synthesizer task has stopped.") request = self.request.getContinueRequest(text) self.__send_str(request) + # pylint: disable=useless-return def streaming_call(self, text: str): """ - Streaming input mode: You can call the stream_call function multiple times to send text. 
+ Streaming input mode: You can call the stream_call function multiple times to send text. # noqa: E501 # pylint: disable=line-too-long A session will be created on the first call. The session ends after calling streaming_complete. Parameters: @@ -478,22 +516,24 @@ def streaming_complete(self, complete_timeout_millis=600000): Parameters: ----------- complete_timeout_millis: int - Throws TimeoutError exception if it times out. If the timeout is not None + Throws TimeoutError exception if it times out. If the timeout is not None # noqa: E501 and greater than zero, it will wait for the corresponding number of milliseconds; otherwise, it will wait indefinitely. """ if not self._is_started: - raise InvalidTask('speech synthesizer has not been started.') + raise InvalidTask("speech synthesizer has not been started.") if self._stopped.is_set(): - raise InvalidTask('speech synthesizer task has stopped.') + raise InvalidTask("speech synthesizer task has stopped.") request = self.request.getFinishRequest() self.__send_str(request) if complete_timeout_millis is not None and complete_timeout_millis > 0: - if not self.complete_event.wait(timeout=complete_timeout_millis / - 1000): + if not self.complete_event.wait( + timeout=complete_timeout_millis / 1000, + ): raise TimeoutError( - 'speech synthesizer wait for complete timeout {}ms'.format( - complete_timeout_millis)) + f"speech synthesizer wait for complete timeout " + f"{complete_timeout_millis}ms", + ) else: self.complete_event.wait() if self._close_ws_after_use: @@ -505,7 +545,7 @@ def __waiting_for_complete(self, timeout): if timeout is not None and timeout > 0: if not self.complete_event.wait(timeout=timeout / 1000): raise TimeoutError( - f'speech synthesizer wait for complete timeout {timeout}ms' + f"speech synthesizer wait for complete timeout {timeout}ms", # noqa: E501 ) else: self.complete_event.wait() @@ -516,26 +556,28 @@ def __waiting_for_complete(self, timeout): def async_streaming_complete(self, 
complete_timeout_millis=600000): """ - Asynchronously stop the streaming input speech synthesis task, returns immediately. - You need to listen and handle the STREAM_INPUT_TTS_EVENT_SYNTHESIS_COMPLETE event in the on_event callback. + Asynchronously stop the streaming input speech synthesis task, returns immediately. # noqa: E501 # pylint: disable=line-too-long + You need to listen and handle the STREAM_INPUT_TTS_EVENT_SYNTHESIS_COMPLETE event in the on_event callback. # noqa: E501 # pylint: disable=line-too-long Do not destroy the object and callback before this event. Parameters: ----------- complete_timeout_millis: int - Throws TimeoutError exception if it times out. If the timeout is not None + Throws TimeoutError exception if it times out. If the timeout is not None # noqa: E501 and greater than zero, it will wait for the corresponding number of milliseconds; otherwise, it will wait indefinitely. """ if not self._is_started: - raise InvalidTask('speech synthesizer has not been started.') + raise InvalidTask("speech synthesizer has not been started.") if self._stopped.is_set(): - raise InvalidTask('speech synthesizer task has stopped.') + raise InvalidTask("speech synthesizer task has stopped.") request = self.request.getFinishRequest() self.__send_str(request) - thread = threading.Thread(target=self.__waiting_for_complete, - args=(complete_timeout_millis, )) + thread = threading.Thread( + target=self.__waiting_for_complete, + args=(complete_timeout_millis,), + ) thread.start() def streaming_cancel(self): @@ -545,7 +587,7 @@ def streaming_cancel(self): """ if not self._is_started: - raise InvalidTask('speech synthesizer has not been started.') + raise InvalidTask("speech synthesizer has not been started.") if self._stopped.is_set(): return request = self.request.getFinishRequest() @@ -555,14 +597,18 @@ def streaming_cancel(self): self.complete_event.set() # 监听消息的回调函数 - def on_message(self, ws, message): + def on_message( # pylint: 
disable=unused-argument,too-many-branches + self, + ws, + message, + ): if isinstance(message, str): - logger.debug('<< 100: - raise ValueError('max_size must be less than 100') + raise ValueError("max_size must be less than 100") self._pool = [] # 如果重连中,则会将avaliable置为False,避免被使用 self._avaliable = [] self._pool_size = max_size - for i in range(self._pool_size): + for i in range(self._pool_size): # pylint: disable=unused-variable synthesizer = self.__get_default_synthesizer() tmpPoolObject = self.PoolObject(synthesizer) tmpPoolObject.synthesizer._SpeechSynthesizer__connect() @@ -734,29 +796,38 @@ def __init__(self, self._lock = threading.Lock() self._stop = False self._stop_lock = threading.Lock() - self._working_thread = threading.Thread(target=self.__auto_reconnect, - args=()) + self._working_thread = threading.Thread( + target=self.__auto_reconnect, + args=(), + ) self._working_thread.start() def __get_default_synthesizer(self) -> SpeechSynthesizer: - return SpeechSynthesizer(model=self.DEFAULT_MODEL, - voice=self.DEFAULT_VOICE, - url=self.DEFAULT_URL, - headers=self.DEFAUTL_HEADERS, - workspace=self.DEFAULT_WORKSPACE) + return SpeechSynthesizer( + model=self.DEFAULT_MODEL, + voice=self.DEFAULT_VOICE, + url=self.DEFAULT_URL, + headers=self.DEFAUTL_HEADERS, + workspace=self.DEFAULT_WORKSPACE, + ) def __get_reconnect_interval(self): return self.DEFAULT_RECONNECT_INTERVAL + random.random() * 10 - 5 def __auto_reconnect(self): logger.debug( - 'speech synthesizer object pool auto reconnect thread start') + "speech synthesizer object pool auto reconnect thread start", + ) while True: objects_need_to_connect = [] objects_need_to_renew = [] - logger.debug('scanning queue borr: {}/{} remain: {}/{}'.format( - self._borrowed_object_num, self._pool_size, - self._remain_object_num, self._pool_size)) + logger.debug( + "scanning queue borr: %s/%s remain: %s/%s", + self._borrowed_object_num, + self._pool_size, + self._remain_object_num, + self._pool_size, + ) with self._lock: if 
self._stop: return @@ -767,26 +838,34 @@ def __auto_reconnect(self): if poolObject.connect_time == -1: objects_need_to_connect.append(poolObject) self._avaliable[idx] = False - elif (not poolObject.synthesizer. - _SpeechSynthesizer__is_connected()) or ( - current_time - poolObject.connect_time > - self.__get_reconnect_interval()): + elif ( + # Access private method for connection check + not poolObject.synthesizer._SpeechSynthesizer__is_connected() # pylint: disable=protected-access # noqa: E501 + ) or ( + current_time - poolObject.connect_time + > self.__get_reconnect_interval() + ): objects_need_to_renew.append(poolObject) self._avaliable[idx] = False for poolObject in objects_need_to_connect: logger.info( - '[SpeechSynthesizerObjectPool] pre-connect new synthesizer' + "[SpeechSynthesizerObjectPool] pre-connect new synthesizer", # noqa: E501 ) - poolObject.synthesizer._SpeechSynthesizer__connect() + # Access private method to establish connection + poolObject.synthesizer._SpeechSynthesizer__connect() # pylint: disable=protected-access # noqa: E501 poolObject.connect_time = time.time() for poolObject in objects_need_to_renew: + # pylint: disable=line-too-long logger.info( - '[SpeechSynthesizerObjectPool] renew synthesizer after {} s' - .format(current_time - poolObject.connect_time)) + "[SpeechSynthesizerObjectPool] renew synthesizer after %s s", # noqa: E501 + current_time - poolObject.connect_time, + ) poolObject.synthesizer = self.__get_default_synthesizer() - poolObject.synthesizer._SpeechSynthesizer__connect() + # Access private method to establish connection + poolObject.synthesizer._SpeechSynthesizer__connect() # pylint: disable=protected-access # noqa: E501 poolObject.connect_time = time.time() with self._lock: + # pylint: disable=consider-using-enumerate for i in range(len(self._avaliable)): self._avaliable[i] = True time.sleep(1) @@ -796,14 +875,14 @@ def shutdown(self): This is a ThreadSafe Method. 
destroy the object pool """ - logger.debug('[SpeechSynthesizerObjectPool] start shutdown') + logger.debug("[SpeechSynthesizerObjectPool] start shutdown") with self._lock: self._stop = True self._pool = [] self._working_thread.join() - logger.debug('[SpeechSynthesizerObjectPool] shutdown complete') + logger.debug("[SpeechSynthesizerObjectPool] shutdown complete") - def borrow_synthesizer( + def borrow_synthesizer( # pylint: disable=unused-argument,redefined-builtin # noqa: E501 self, model, voice, @@ -829,14 +908,16 @@ def borrow_synthesizer( If there is no synthesizer object in the pool, a new synthesizer object will be created and returned. """ - logger.debug('[SpeechSynthesizerObjectPool] get synthesizer') + logger.debug("[SpeechSynthesizerObjectPool] get synthesizer") synthesizer: SpeechSynthesizer = None with self._lock: # 遍历对象池,如果存在预建连的对象,则返回 for idx, poolObject in enumerate(self._pool): - if self._avaliable[ - idx] and poolObject.synthesizer._SpeechSynthesizer__is_connected( - ): + if ( + self._avaliable[idx] + # Access private method for connection check + and poolObject.synthesizer._SpeechSynthesizer__is_connected() # pylint: disable=protected-access # noqa: E501 + ): synthesizer = poolObject.synthesizer self._borrowed_object_num += 1 self._remain_object_num -= 1 @@ -848,31 +929,45 @@ def borrow_synthesizer( if synthesizer is None: synthesizer = self.__get_default_synthesizer() logger.warning( - '[SpeechSynthesizerObjectPool] object pool is exausted, create new synthesizer' + "[SpeechSynthesizerObjectPool] object pool is exausted, create new synthesizer", # noqa: E501 # pylint: disable=line-too-long ) - synthesizer._SpeechSynthesizer__reset() - synthesizer._SpeechSynthesizer__update_params(model, voice, format, - volume, speech_rate, - pitch_rate, seed, synthesis_type, instruction, - language_hints, self.DEFAUTL_HEADERS, - callback, self.DEFAULT_WORKSPACE, self.DEFAULT_URL, - additional_params, False) + # Access private methods to reset and update 
synthesizer params + synthesizer._SpeechSynthesizer__reset() # pylint: disable=protected-access # noqa: E501 + synthesizer._SpeechSynthesizer__update_params( # pylint: disable=protected-access # noqa: E501 + model, + voice, + format, + volume, + speech_rate, + pitch_rate, + seed, + synthesis_type, + instruction, + language_hints, + self.DEFAUTL_HEADERS, + callback, + self.DEFAULT_WORKSPACE, + self.DEFAULT_URL, + additional_params, + False, + ) return synthesizer - def return_synthesizer(self, synthesizer) -> bool: + # pylint: disable=inconsistent-return-statements + def return_synthesizer(self, synthesizer) -> bool: # type: ignore[return] """ This is a ThreadSafe Method. return a synthesizer object back to the pool. """ if not isinstance(synthesizer, SpeechSynthesizer): logger.error( - '[SpeechSynthesizerObjectPool] return_synthesizer: synthesizer is not a SpeechSynthesizer object' + "[SpeechSynthesizerObjectPool] return_synthesizer: synthesizer is not a SpeechSynthesizer object", # noqa: E501 # pylint: disable=line-too-long ) return False with self._lock: if self._borrowed_object_num <= 0: logger.debug( - '[SpeechSynthesizerObjectPool] pool is full, drop returned object' + "[SpeechSynthesizerObjectPool] pool is full, drop returned object", # noqa: E501 # pylint: disable=line-too-long ) return False poolObject = self.PoolObject(synthesizer) @@ -882,5 +977,5 @@ def return_synthesizer(self, synthesizer) -> bool: self._borrowed_object_num -= 1 self._remain_object_num += 1 logger.debug( - '[SpeechSynthesizerObjectPool] return synthesizer back to pool' + "[SpeechSynthesizerObjectPool] return synthesizer back to pool", # noqa: E501 ) diff --git a/dashscope/cli.py b/dashscope/cli.py index 9e7c47e..e4da85d 100644 --- a/dashscope/cli.py +++ b/dashscope/cli.py @@ -1,4 +1,5 @@ #!/usr/bin/env python +# -*- coding: utf-8 -*- import argparse import json import sys @@ -8,14 +9,19 @@ import dashscope from dashscope.aigc import Generation -from dashscope.common.constants import 
(DeploymentStatus, FilePurpose, - TaskStatus) +from dashscope.common.constants import ( + DeploymentStatus, + FilePurpose, + TaskStatus, +) from dashscope.utils.oss_utils import OssUtils def print_failed_message(rsp): - print('Failed, request_id: %s, status_code: %s, code: %s, message: %s' % - (rsp.request_id, rsp.status_code, rsp.code, rsp.message)) + print( + f"Failed, request_id: {rsp.request_id}, status_code: " + f"{rsp.status_code}, code: {rsp.code}, message: {rsp.message}", + ) def text_generation(args): @@ -40,13 +46,13 @@ class FineTunes: def call(cls, args): params = {} if args.n_epochs is not None: - params['n_epochs'] = args.n_epochs + params["n_epochs"] = args.n_epochs if args.batch_size is not None: - params['batch_size'] = args.batch_size + params["batch_size"] = args.batch_size if args.learning_rate is not None: - params['learning_rate'] = args.learning_rate + params["learning_rate"] = args.learning_rate if args.prompt_loss is not None: - params['prompt_loss'] = args.prompt_loss + params["prompt_loss"] = args.prompt_loss if args.params: params.update(args.params) @@ -55,11 +61,14 @@ def call(cls, args): training_file_ids=args.training_file_ids, validation_file_ids=args.validation_file_ids, mode=args.mode, - hyper_parameters=params) + hyper_parameters=params, + ) if rsp.status_code == HTTPStatus.OK: - print('Create fine-tune job success, job_id: %s' % - rsp.output['job_id']) - cls.wait(rsp.output['job_id']) + print( + f"Create fine-tune job success, job_id: " + f"{rsp.output['job_id']}", + ) + cls.wait(rsp.output["job_id"]) else: print_failed_message(rsp) @@ -69,65 +78,78 @@ def wait(cls, job_id): while True: rsp = dashscope.FineTunes.get(job_id) if rsp.status_code == HTTPStatus.OK: - if rsp.output['status'] == TaskStatus.FAILED: - print('Fine-tune FAILED!') + if rsp.output["status"] == TaskStatus.FAILED: + print("Fine-tune FAILED!") break - elif rsp.output['status'] == TaskStatus.CANCELED: - print('Fine-tune task CANCELED') + if ( # pylint: 
disable=no-else-break + rsp.output["status"] == TaskStatus.CANCELED + ): + print("Fine-tune task CANCELED") break - elif rsp.output['status'] == TaskStatus.RUNNING: + if rsp.output["status"] == TaskStatus.RUNNING: print( - 'Fine-tuning is RUNNING, start get output stream.') + "Fine-tuning is RUNNING, start get output stream.", + ) cls.stream_events(job_id) - elif rsp.output['status'] == TaskStatus.SUCCEEDED: - print('Fine-tune task success, fine-tuned model:%s' % - rsp.output['finetuned_output']) + elif rsp.output["status"] == TaskStatus.SUCCEEDED: + print( + f"Fine-tune task success, fine-tuned model:" + f"{rsp.output['finetuned_output']}", + ) break else: - print('The fine-tune task is: %s' % - rsp.output['status']) + print( + f"The fine-tune task is: {rsp.output['status']}", + ) time.sleep(30) else: print_failed_message(rsp) except Exception: print( - 'You can stream output via: dashscope fine_tunes.stream -j %s' - % job_id) + f"You can stream output via: dashscope fine_tunes.stream -j " + f"{job_id}", + ) @classmethod def get(cls, args): rsp = dashscope.FineTunes.get(args.job) if rsp.status_code == HTTPStatus.OK: - if rsp.output['status'] == TaskStatus.FAILED: - print('Fine-tune failed!') - elif rsp.output['status'] == TaskStatus.CANCELED: - print('Fine-tune task canceled') - elif rsp.output['status'] == TaskStatus.SUCCEEDED: - print('Fine-tune task success, fine-tuned model : %s' % - rsp.output['finetuned_output']) + if rsp.output["status"] == TaskStatus.FAILED: + print("Fine-tune failed!") + elif rsp.output["status"] == TaskStatus.CANCELED: + print("Fine-tune task canceled") + elif rsp.output["status"] == TaskStatus.SUCCEEDED: + print( + f"Fine-tune task success, fine-tuned model : " + f"{rsp.output['finetuned_output']}", + ) else: - print('The fine-tune task is: %s' % rsp.output['status']) + print(f"The fine-tune task is: {rsp.output['status']}") else: print_failed_message(rsp) @classmethod def list(cls, args): - rsp = 
dashscope.FineTunes.list(page=args.start_page, - page_size=args.page_size) + rsp = dashscope.FineTunes.list( + page=args.start_page, + page_size=args.page_size, + ) if rsp.status_code == HTTPStatus.OK: if rsp.output is not None: - for job in rsp.output['jobs']: - if job['status'] == TaskStatus.SUCCEEDED: + for job in rsp.output["jobs"]: + if job["status"] == TaskStatus.SUCCEEDED: print( - 'job: %s, status: %s, base model: %s, fine-tuned model: %s' # noqa E501 - % # noqa - (job['job_id'], job['status'], job['model'], - job['finetuned_output'])) + f"job: {job['job_id']}, status: {job['status']}, " + f"base model: {job['model']}, " + f"fine-tuned model: {job['finetuned_output']}", + ) else: - print('job: %s, status: %s, base model: %s' % - (job['job_id'], job['status'], job['model'])) + print( + f"job: {job['job_id']}, status: {job['status']}, " + f"base model: {job['model']}", + ) else: - print('There is no fine-tuned model.') + print("There is no fine-tuned model.") else: print_failed_message(rsp) @@ -136,12 +158,14 @@ def stream_events(cls, job_id): # check job status if job is completed, get log. 
rsp = dashscope.FineTunes.get(job_id) if rsp.status_code == HTTPStatus.OK: - if rsp.output['status'] in [ - TaskStatus.FAILED, TaskStatus.CANCELED, - TaskStatus.SUCCEEDED + if rsp.output["status"] in [ + TaskStatus.FAILED, + TaskStatus.CANCELED, + TaskStatus.SUCCEEDED, ]: - print('Fine-tune job: %s is %s' % - (job_id, rsp.output['status'])) + print( + f"Fine-tune job: {job_id} is {rsp.output['status']}", + ) cls.log(job_id) return else: @@ -157,8 +181,9 @@ def stream_events(cls, job_id): print_failed_message(rsp) except Exception: print( - 'You can stream output via: dashscope fine-tunes.stream -j %s' - % job_id) + f"You can stream output via: dashscope fine-tunes.stream -j " + f"{job_id}", + ) @classmethod def events(cls, args): @@ -171,12 +196,11 @@ def log(cls, job_id): while True: rsp = dashscope.FineTunes.logs(job_id, offset=start, line=n_line) if rsp.status_code == HTTPStatus.OK: - for line in rsp.output['logs']: + for line in rsp.output["logs"]: print(line) - if rsp.output['total'] < n_line: + if rsp.output["total"] < n_line: break - else: - start += n_line + start += n_line # pylint: disable=no-else-break else: print_failed_message(rsp) @@ -184,7 +208,7 @@ def log(cls, job_id): def cancel(cls, args): rsp = dashscope.FineTunes.cancel(args.job) if rsp.status_code == HTTPStatus.OK: - print('Cancel fine-tune job: %s success!') + print("Cancel fine-tune job: %s success!") else: print_failed_message(rsp) @@ -192,51 +216,64 @@ def cancel(cls, args): def delete(cls, args): rsp = dashscope.FineTunes.delete(args.job) if rsp.status_code == HTTPStatus.OK: - print('fine_tune job: %s delete success' % args.job) + print(f"fine_tune job: {args.job} delete success") else: print_failed_message(rsp) + class Oss: @classmethod def upload(cls, args): - print('Start oss.upload: model=%s, file=%s, api_key=%s' % (args.model, args.file, args.api_key)) + print( + f"Start oss.upload: model={args.model}, file={args.file}, " + f"api_key={args.api_key}", + ) if not args.file or not 
args.model: - print('Please specify the model and file path') + print("Please specify the model and file path") return file_path = os.path.expanduser(args.file) if not os.path.exists(file_path): - print('File %s does not exist' % file_path) + print(f"File {file_path} does not exist") return - api_key = os.environ.get('DASHSCOPE_API_KEY', args.api_key) + api_key = os.environ.get("DASHSCOPE_API_KEY", args.api_key) if not api_key: - print('Please set your DashScope API key as environment variable ' - 'DASHSCOPE_API_KEY or pass it as argument by -k/--api_key') + print( + "Please set your DashScope API key as environment variable " + "DASHSCOPE_API_KEY or pass it as argument by -k/--api_key", + ) return - oss_url, _ = OssUtils.upload(model=args.model, - file_path=file_path, - api_key=api_key, - base_address=args.base_url) + oss_url, _ = OssUtils.upload( + model=args.model, + file_path=file_path, + api_key=api_key, + base_address=args.base_url, + ) if not oss_url: - print('Failed to upload file: %s' % file_path) + print(f"Failed to upload file: {file_path}") return - print('Uploaded oss url: %s' % oss_url) + print(f"Uploaded oss url: {oss_url}") + class Files: @classmethod def upload(cls, args): - rsp = dashscope.Files.upload(file_path=args.file, - purpose=args.purpose, - description=args.description, - base_address=args.base_url) + rsp = dashscope.Files.upload( + file_path=args.file, + purpose=args.purpose, + description=args.description, + base_address=args.base_url, + ) print(rsp) if rsp.status_code == HTTPStatus.OK: - print('Upload success, file id: %s' % - rsp.output['uploaded_files'][0]['file_id']) + print( + f"Upload success, file id: " + f"{rsp.output['uploaded_files'][0]['file_id']}", + ) else: print_failed_message(rsp) @@ -245,22 +282,30 @@ def get(cls, args): rsp = dashscope.Files.get(file_id=args.id, base_address=args.base_url) if rsp.status_code == HTTPStatus.OK: if rsp.output: - print('file info:\n%s' % json.dumps(rsp.output, ensure_ascii=False, indent=4)) 
+ print( + f"file info:\n" + f"{json.dumps(rsp.output, ensure_ascii=False, indent=4)}", + ) else: - print('There is no uploaded file.') + print("There is no uploaded file.") else: print_failed_message(rsp) @classmethod def list(cls, args): - rsp = dashscope.Files.list(page=args.start_page, - page_size=args.page_size, - base_address=args.base_url) + rsp = dashscope.Files.list( + page=args.start_page, + page_size=args.page_size, + base_address=args.base_url, + ) if rsp.status_code == HTTPStatus.OK: if rsp.output: - print('file list info:\n%s' % json.dumps(rsp.output, ensure_ascii=False, indent=4)) + print( + f"file list info:\n" + f"{json.dumps(rsp.output, ensure_ascii=False, indent=4)}", + ) else: - print('There is no uploaded files.') + print("There is no uploaded files.") else: print_failed_message(rsp) @@ -268,7 +313,7 @@ def list(cls, args): def delete(cls, args): rsp = dashscope.Files.delete(args.id, base_address=args.base_url) if rsp.status_code == HTTPStatus.OK: - print('Delete success') + print("Delete success") else: print_failed_message(rsp) @@ -276,33 +321,41 @@ def delete(cls, args): class Deployments: @classmethod def call(cls, args): - rsp = dashscope.Deployments.call(model=args.model, - capacity=args.capacity, - suffix=args.suffix) + rsp = dashscope.Deployments.call( + model=args.model, + capacity=args.capacity, + suffix=args.suffix, + ) if rsp.status_code == HTTPStatus.OK: - deployed_model = rsp.output['deployed_model'] - print('Create model: %s deployment' % deployed_model) + deployed_model = rsp.output["deployed_model"] + print(f"Create model: {deployed_model} deployment") try: while True: # wait for deployment ok. 
status = dashscope.Deployments.get(deployed_model) if status.status_code == HTTPStatus.OK: - if status.output['status'] in [ - DeploymentStatus.PENDING, - DeploymentStatus.DEPLOYING + if status.output["status"] in [ + DeploymentStatus.PENDING, + DeploymentStatus.DEPLOYING, ]: time.sleep(30) - print('Deployment %s is %s' % - (deployed_model, status.output['status'])) + print( + f"Deployment {deployed_model} is " + f"{status.output['status']}", + ) else: - print('Deployment: %s status: %s' % - (deployed_model, status.output['status'])) + print( + f"Deployment: {deployed_model} status: " + f"{status.output['status']}", + ) break else: print_failed_message(rsp) except Exception: - print('You can get deployment status via: \ - dashscope deployments.get -d %s' % deployed_model) + print( + f"You can get deployment status via: " + f"dashscope deployments.get -d {deployed_model}", + ) else: print_failed_message(rsp) @@ -310,28 +363,39 @@ def call(cls, args): def get(cls, args): rsp = dashscope.Deployments.get(args.deploy) if rsp.status_code == HTTPStatus.OK: - print('Deployed model: %s capacity: %s status: %s' % - (rsp.output['deployed_model'], rsp.output['capacity'], - rsp.output['status'])) + print( + f"Deployed model: {rsp.output['deployed_model']} " + f"capacity: {rsp.output['capacity']} " + f"status: {rsp.output['status']}", + ) else: print_failed_message(rsp) @classmethod def list(cls, args): - rsp = dashscope.Deployments.list(page_no=args.start_page, - page_size=args.page_size) + rsp = dashscope.Deployments.list( + page_no=args.start_page, + page_size=args.page_size, + ) if rsp.status_code == HTTPStatus.OK: if rsp.output is not None: - if 'deployments' not in rsp.output or len( - rsp.output['deployments']) == 0: - print('There is no deployed model!') + if ( + "deployments" not in rsp.output + or len( + rsp.output["deployments"], + ) + == 0 + ): + print("There is no deployed model!") return - for deployment in rsp.output['deployments']: - print('Deployed_model: %s, 
model: %s, status: %s' % - (deployment['deployed_model'], - deployment['model_name'], deployment['status'])) + for deployment in rsp.output["deployments"]: + print( + f"Deployed_model: {deployment['deployed_model']}, " + f"model: {deployment['model_name']}, " + f"status: {deployment['status']}", + ) else: - print('There is no deployed model.') + print("There is no deployed model.") else: print_failed_message(rsp) @@ -340,15 +404,17 @@ def update(cls, args): rsp = dashscope.Deployments.update(args.deployed_model, args.version) if rsp.status_code == HTTPStatus.OK: if rsp.output is not None: - if 'deployments' not in rsp.output: - print('There is no deployed model!') + if "deployments" not in rsp.output: + print("There is no deployed model!") return - for deployment in rsp.output['deployments']: - print('Deployed_model: %s, model: %s, status: %s' % - (deployment['deployed_model'], - deployment['model_name'], deployment['status'])) + for deployment in rsp.output["deployments"]: + print( + f"Deployed_model: {deployment['deployed_model']}, " + f"model: {deployment['model_name']}, " + f"status: {deployment['status']}", + ) else: - print('There is no deployed model.') + print("There is no deployed model.") else: print_failed_message(rsp) @@ -357,11 +423,13 @@ def scale(cls, args): rsp = dashscope.Deployments.scale(args.deployed_model, args.capacity) if rsp.status_code == HTTPStatus.OK: if rsp.output is not None: - print('Deployed_model: %s, model: %s, status: %s' % - (rsp.output['deployed_model'], rsp.output['model_name'], - rsp.output['status'])) + print( + f"Deployed_model: {rsp.output['deployed_model']}, " + f"model: {rsp.output['model_name']}, " + f"status: {rsp.output['status']}", + ) else: - print('There is no deployed model.') + print("There is no deployed model.") else: print_failed_message(rsp) @@ -369,7 +437,7 @@ def scale(cls, args): def delete(cls, args): rsp = dashscope.Deployments.delete(args.deploy) if rsp.status_code == HTTPStatus.OK: - print('Deployed 
model: %s delete success' % args.deploy) + print(f"Deployed model: {args.deploy} delete success") else: print_failed_message(rsp) @@ -377,305 +445,374 @@ def delete(cls, args): # from: https://gist.github.com/vadimkantorov/37518ff88808af840884355c845049ea class ParseKVAction(argparse.Action): def __call__(self, parser, namespace, values, option_string=None): - setattr(namespace, self.dest, dict()) + # pylint: disable=use-dict-literal + setattr( + namespace, + self.dest, + dict(), + ) for each in values: try: - key, value = each.split('=') + key, value = each.split("=") getattr(namespace, self.dest)[key] = value except ValueError as ex: - message = '\nTraceback: {}'.format(ex) - message += "\nError on '{}' || It should be 'key=value'".format( - each) + message = f"\nTraceback: {ex}" + message += f"\nError on '{each}' || It should be 'key=value'" raise argparse.ArgumentError(self, str(message)) +# pylint: disable=too-many-statements def main(): parser = argparse.ArgumentParser( - prog='dashscope', description='dashscope command line tools.') - parser.add_argument('-k', '--api-key', help='Dashscope API key.') - sub_parsers = parser.add_subparsers(help='Api subcommands') - text_generation_parser = sub_parsers.add_parser('generation.call') - text_generation_parser.add_argument('-p', - '--prompt', - type=str, - required=True, - help='Input prompt') - text_generation_parser.add_argument('-m', - '--model', - type=str, - required=True, - help='The model to call.') - text_generation_parser.add_argument('--history', - type=str, - required=False, - help='The history of the request.') - text_generation_parser.add_argument('-s', - '--stream', - default=False, - action='store_true', - help='Use stream mode default false.') + prog="dashscope", + description="dashscope command line tools.", + ) + parser.add_argument("-k", "--api-key", help="Dashscope API key.") + sub_parsers = parser.add_subparsers(help="Api subcommands") + text_generation_parser = 
sub_parsers.add_parser("generation.call") + text_generation_parser.add_argument( + "-p", + "--prompt", + type=str, + required=True, + help="Input prompt", + ) + text_generation_parser.add_argument( + "-m", + "--model", + type=str, + required=True, + help="The model to call.", + ) + text_generation_parser.add_argument( + "--history", + type=str, + required=False, + help="The history of the request.", + ) + text_generation_parser.add_argument( + "-s", + "--stream", + default=False, + action="store_true", + help="Use stream mode default false.", + ) text_generation_parser.set_defaults(func=text_generation) - fine_tune_call = sub_parsers.add_parser('fine_tunes.call') + fine_tune_call = sub_parsers.add_parser("fine_tunes.call") fine_tune_call.add_argument( - '-t', - '--training_file_ids', + "-t", + "--training_file_ids", required=True, - nargs='+', - help='Training file ids which upload by File command.') + nargs="+", + help="Training file ids which upload by File command.", + ) fine_tune_call.add_argument( - '-v', - '--validation_file_ids', + "-v", + "--validation_file_ids", required=False, - nargs='+', + nargs="+", default=[], - help='Validation file ids which upload by File command.') - fine_tune_call.add_argument('-m', - '--model', - type=str, - required=True, - help='The based model to start fine-tune.') + help="Validation file ids which upload by File command.", + ) fine_tune_call.add_argument( - '--mode', + "-m", + "--model", type=str, + required=True, + help="The based model to start fine-tune.", + ) + fine_tune_call.add_argument( + "--mode", + type=str, + required=False, + choices=["sft", "efficient_sft"], + help="Select fine-tune mode sft or efficient_sft", + ) + fine_tune_call.add_argument( + "-e", + "--n_epochs", + type=int, required=False, - choices=['sft', 'efficient_sft'], - help='Select fine-tune mode sft or efficient_sft') - fine_tune_call.add_argument('-e', - '--n_epochs', - type=int, - required=False, - help='How many epochs to fine-tune.') - 
fine_tune_call.add_argument('-b', - '--batch_size', - type=int, - required=False, - help='How big is batch_size.') - fine_tune_call.add_argument('-l', - '--learning_rate', - type=float, - required=False, - help='The fine-tune learning rate.') - fine_tune_call.add_argument('-p', - '--prompt_loss', - type=float, - required=False, - help='The fine-tune prompt loss.') + help="How many epochs to fine-tune.", + ) + fine_tune_call.add_argument( + "-b", + "--batch_size", + type=int, + required=False, + help="How big is batch_size.", + ) fine_tune_call.add_argument( - '--hyper_parameters', - nargs='+', - dest='params', + "-l", + "--learning_rate", + type=float, + required=False, + help="The fine-tune learning rate.", + ) + fine_tune_call.add_argument( + "-p", + "--prompt_loss", + type=float, + required=False, + help="The fine-tune prompt loss.", + ) + fine_tune_call.add_argument( + "--hyper_parameters", + nargs="+", + dest="params", action=ParseKVAction, - help='Extra hyper parameters accepts by key1=value1 key2=value2', - metavar='KEY1=VALUE1') + help="Extra hyper parameters accepts by key1=value1 key2=value2", + metavar="KEY1=VALUE1", + ) fine_tune_call.set_defaults(func=FineTunes.call) - fine_tune_get = sub_parsers.add_parser('fine_tunes.get') - fine_tune_get.add_argument('-j', - '--job', - type=str, - required=True, - help='The fine-tune job id.') + fine_tune_get = sub_parsers.add_parser("fine_tunes.get") + fine_tune_get.add_argument( + "-j", + "--job", + type=str, + required=True, + help="The fine-tune job id.", + ) fine_tune_get.set_defaults(func=FineTunes.get) - fine_tune_delete = sub_parsers.add_parser('fine_tunes.delete') - fine_tune_delete.add_argument('-j', - '--job', - type=str, - required=True, - help='The fine-tune job id.') + fine_tune_delete = sub_parsers.add_parser("fine_tunes.delete") + fine_tune_delete.add_argument( + "-j", + "--job", + type=str, + required=True, + help="The fine-tune job id.", + ) fine_tune_delete.set_defaults(func=FineTunes.delete) - 
fine_tune_stream = sub_parsers.add_parser('fine_tunes.stream') - fine_tune_stream.add_argument('-j', - '--job', - type=str, - required=True, - help='The fine-tune job id.') + fine_tune_stream = sub_parsers.add_parser("fine_tunes.stream") + fine_tune_stream.add_argument( + "-j", + "--job", + type=str, + required=True, + help="The fine-tune job id.", + ) fine_tune_stream.set_defaults(func=FineTunes.events) - fine_tune_list = sub_parsers.add_parser('fine_tunes.list') - fine_tune_list.add_argument('-s', - '--start_page', - type=int, - default=1, - help='Start of page, default 1') - fine_tune_list.add_argument('-p', - '--page_size', - type=int, - default=10, - help='The page size, default 10') + fine_tune_list = sub_parsers.add_parser("fine_tunes.list") + fine_tune_list.add_argument( + "-s", + "--start_page", + type=int, + default=1, + help="Start of page, default 1", + ) + fine_tune_list.add_argument( + "-p", + "--page_size", + type=int, + default=10, + help="The page size, default 10", + ) fine_tune_list.set_defaults(func=FineTunes.list) - fine_tune_cancel = sub_parsers.add_parser('fine_tunes.cancel') - fine_tune_cancel.add_argument('-j', - '--job', - type=str, - required=True, - help='The fine-tune job id.') + fine_tune_cancel = sub_parsers.add_parser("fine_tunes.cancel") + fine_tune_cancel.add_argument( + "-j", + "--job", + type=str, + required=True, + help="The fine-tune job id.", + ) fine_tune_cancel.set_defaults(func=FineTunes.cancel) - oss_upload = sub_parsers.add_parser('oss.upload') + oss_upload = sub_parsers.add_parser("oss.upload") oss_upload.add_argument( - '-f', - '--file', + "-f", + "--file", type=str, required=True, - help='The file path to upload', + help="The file path to upload", ) oss_upload.add_argument( - '-m', - '--model', + "-m", + "--model", type=str, required=True, - help='The model name', + help="The model name", ) oss_upload.add_argument( - '-k', - '--api_key', + "-k", + "--api_key", type=str, required=False, - help='The dashscope api key', + 
help="The dashscope api key", ) oss_upload.add_argument( - '-u', - '--base_url', + "-u", + "--base_url", type=str, - help='The base url.', + help="The base url.", required=False, ) oss_upload.set_defaults(func=Oss.upload) - file_upload = sub_parsers.add_parser('files.upload') + file_upload = sub_parsers.add_parser("files.upload") file_upload.add_argument( - '-f', - '--file', + "-f", + "--file", type=str, required=True, - help='The file path to upload', + help="The file path to upload", ) file_upload.add_argument( - '-p', - '--purpose', + "-p", + "--purpose", default=FilePurpose.fine_tune, const=FilePurpose.fine_tune, - nargs='?', - help='Purpose to upload file[fine-tune]', + nargs="?", + help="Purpose to upload file[fine-tune]", required=True, ) file_upload.add_argument( - '-d', - '--description', + "-d", + "--description", type=str, - help='The file description.', + help="The file description.", required=False, ) file_upload.add_argument( - '-u', - '--base_url', + "-u", + "--base_url", type=str, - help='The base url.', + help="The base url.", required=False, ) file_upload.set_defaults(func=Files.upload) - file_get = sub_parsers.add_parser('files.get') - file_get.add_argument('-i', - '--id', - type=str, - required=True, - help='The file ID') + file_get = sub_parsers.add_parser("files.get") file_get.add_argument( - '-u', - '--base_url', + "-i", + "--id", type=str, - help='The base url.', + required=True, + help="The file ID", + ) + file_get.add_argument( + "-u", + "--base_url", + type=str, + help="The base url.", required=False, ) file_get.set_defaults(func=Files.get) - file_delete = sub_parsers.add_parser('files.delete') - file_delete.add_argument('-i', - '--id', - type=str, - required=True, - help='The files ID') + file_delete = sub_parsers.add_parser("files.delete") + file_delete.add_argument( + "-i", + "--id", + type=str, + required=True, + help="The files ID", + ) file_delete.add_argument( - '-u', - '--base_url', + "-u", + "--base_url", type=str, - help='The 
base url.', + help="The base url.", required=False, ) file_delete.set_defaults(func=Files.delete) - file_list = sub_parsers.add_parser('files.list') - file_list.add_argument('-s', - '--start_page', - type=int, - default=1, - help='Start of page, default 1') - file_list.add_argument('-p', - '--page_size', - type=int, - default=10, - help='The page size, default 10') + file_list = sub_parsers.add_parser("files.list") file_list.add_argument( - '-u', - '--base_url', + "-s", + "--start_page", + type=int, + default=1, + help="Start of page, default 1", + ) + file_list.add_argument( + "-p", + "--page_size", + type=int, + default=10, + help="The page size, default 10", + ) + file_list.add_argument( + "-u", + "--base_url", type=str, - help='The base url.', + help="The base url.", required=False, ) file_list.set_defaults(func=Files.list) - deployments_call = sub_parsers.add_parser('deployments.call') - deployments_call.add_argument('-m', - '--model', - required=True, - help='The model ID') - deployments_call.add_argument('-s', - '--suffix', - required=False, - help=('The suffix of the deployment, \ - lower cased characters 8 chars max.')) - deployments_call.add_argument('-c', - '--capacity', - type=int, - required=False, - default=1, - help='The target capacity') + deployments_call = sub_parsers.add_parser("deployments.call") + deployments_call.add_argument( + "-m", + "--model", + required=True, + help="The model ID", + ) + deployments_call.add_argument( + "-s", + "--suffix", + required=False, + help=( + "The suffix of the deployment, \ + lower cased characters 8 chars max." 
+ ), + ) + deployments_call.add_argument( + "-c", + "--capacity", + type=int, + required=False, + default=1, + help="The target capacity", + ) deployments_call.set_defaults(func=Deployments.call) - deployments_get = sub_parsers.add_parser('deployments.get') - deployments_get.add_argument('-d', - '--deploy', - required=True, - help='The deployed model.') + deployments_get = sub_parsers.add_parser("deployments.get") + deployments_get.add_argument( + "-d", + "--deploy", + required=True, + help="The deployed model.", + ) deployments_get.set_defaults(func=Deployments.get) - deployments_delete = sub_parsers.add_parser('deployments.delete') - deployments_delete.add_argument('-d', - '--deploy', - required=True, - help='The deployed model.') + deployments_delete = sub_parsers.add_parser("deployments.delete") + deployments_delete.add_argument( + "-d", + "--deploy", + required=True, + help="The deployed model.", + ) deployments_delete.set_defaults(func=Deployments.delete) - deployments_list = sub_parsers.add_parser('deployments.list') - deployments_list.add_argument('-s', - '--start_page', - type=int, - default=1, - help='Start of page, default 1') - deployments_list.add_argument('-p', - '--page_size', - type=int, - default=10, - help='The page size, default 10') + deployments_list = sub_parsers.add_parser("deployments.list") + deployments_list.add_argument( + "-s", + "--start_page", + type=int, + default=1, + help="Start of page, default 1", + ) + deployments_list.add_argument( + "-p", + "--page_size", + type=int, + default=10, + help="The page size, default 10", + ) deployments_list.set_defaults(func=Deployments.list) - deployments_scale = sub_parsers.add_parser('deployments.scale') - deployments_scale.add_argument('-d', - '--deployed_model', - type=str, - required=True, - help='The deployed model to scale') - deployments_scale.add_argument('-c', - '--capacity', - type=int, - required=True, - help='The target capacity') + deployments_scale = 
sub_parsers.add_parser("deployments.scale") + deployments_scale.add_argument( + "-d", + "--deployed_model", + type=str, + required=True, + help="The deployed model to scale", + ) + deployments_scale.add_argument( + "-c", + "--capacity", + type=int, + required=True, + help="The target capacity", + ) deployments_scale.set_defaults(func=Deployments.scale) args = parser.parse_args() @@ -684,5 +821,5 @@ def main(): args.func(args) -if __name__ == '__main__': +if __name__ == "__main__": sys.exit(main()) diff --git a/dashscope/client/base_api.py b/dashscope/client/base_api.py index ae3da62..feb1e75 100644 --- a/dashscope/client/base_api.py +++ b/dashscope/client/base_api.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. import asyncio import collections @@ -11,38 +12,60 @@ from dashscope.api_entities.api_request_factory import _build_api_request from dashscope.api_entities.dashscope_response import DashScopeAPIResponse from dashscope.common.api_key import get_default_api_key -from dashscope.common.constants import (DEFAULT_REQUEST_TIMEOUT_SECONDS, - REPEATABLE_STATUS, - REQUEST_TIMEOUT_KEYWORD, - SSE_CONTENT_TYPE, TaskStatus, HTTPMethod) +from dashscope.common.constants import ( + DEFAULT_REQUEST_TIMEOUT_SECONDS, + REPEATABLE_STATUS, + REQUEST_TIMEOUT_KEYWORD, + SSE_CONTENT_TYPE, + TaskStatus, + HTTPMethod, +) from dashscope.common.error import InvalidParameter, InvalidTask, ModelRequired from dashscope.common.logging import logger -from dashscope.common.utils import (_handle_http_failed_response, - _handle_http_response, - _handle_http_stream_response, - default_headers, join_url) +from dashscope.common.utils import ( + _handle_http_failed_response, + _handle_http_response, + _handle_http_stream_response, + default_headers, + join_url, +) + class AsyncAioTaskGetMixin: @classmethod - async def _get(cls, - task_id: str, - api_key: str = None, - workspace: str = None, - **kwargs) -> DashScopeAPIResponse: - base_url = 
kwargs.pop('base_address', None) - url = _normalization_url(base_url, 'tasks', task_id) + async def _get( + cls, + task_id: str, + api_key: str = None, + workspace: str = None, + **kwargs, + ) -> DashScopeAPIResponse: + base_url = kwargs.pop("base_address", None) + url = _normalization_url(base_url, "tasks", task_id) kwargs = cls._handle_kwargs(api_key, workspace, **kwargs) kwargs["base_address"] = url if not api_key: api_key = get_default_api_key() - request = _build_api_request("", "", "", - "", "", api_key=api_key, - is_service=False, **kwargs) + request = _build_api_request( + "", + "", + "", + "", + "", + api_key=api_key, + is_service=False, + **kwargs, + ) return await cls._handle_request(request) @classmethod - def _handle_kwargs(cls, api_key: str = None ,workspace: str = None, **kwargs): - custom_headers = kwargs.pop('headers', None) + def _handle_kwargs( + cls, + api_key: str = None, + workspace: str = None, + **kwargs, + ): + custom_headers = kwargs.pop("headers", None) headers = { **_workspace_header(workspace), **default_headers(api_key), @@ -54,11 +77,11 @@ def _handle_kwargs(cls, api_key: str = None ,workspace: str = None, **kwargs): } if workspace is not None: headers = { - 'X-DashScope-WorkSpace': workspace, - **kwargs.pop('headers', {}) + "X-DashScope-WorkSpace": workspace, + **kwargs.pop("headers", {}), } - kwargs['headers'] = headers - kwargs['http_method'] = HTTPMethod.GET + kwargs["headers"] = headers + kwargs["http_method"] = HTTPMethod.GET return kwargs @classmethod @@ -74,89 +97,106 @@ async def _handle_request(cls, request): else: return response + class BaseAsyncAioApi(AsyncAioTaskGetMixin): - """BaseApi, internal use only. 
+ """BaseApi, internal use only.""" - """ @classmethod def _validate_params(cls, api_key, model): if api_key is None: api_key = get_default_api_key() if model is None or not model: - raise ModelRequired('Model is required!') + raise ModelRequired("Model is required!") return api_key, model @classmethod - async def async_call(cls, - model: str, - input: object, - task_group: str, - task: str = None, - function: str = None, - api_key: str = None, - workspace: str = None, - **kwargs) -> DashScopeAPIResponse: + async def async_call( + cls, + model: str, + input: object, # pylint: disable=redefined-builtin + task_group: str, + task: str = None, + function: str = None, + api_key: str = None, + workspace: str = None, + **kwargs, + ) -> DashScopeAPIResponse: api_key, model = cls._validate_params(api_key, model) if workspace is not None: headers = { - 'X-DashScope-WorkSpace': workspace, - **kwargs.pop('headers', {}) + "X-DashScope-WorkSpace": workspace, + **kwargs.pop("headers", {}), } - kwargs['headers'] = headers - kwargs['async_request'] = True - request = _build_api_request(model=model, - input=input, - task_group=task_group, - task=task, - function=function, - api_key=api_key, - **kwargs) + kwargs["headers"] = headers + kwargs["async_request"] = True + request = _build_api_request( + model=model, + input=input, + task_group=task_group, + task=task, + function=function, + api_key=api_key, + **kwargs, + ) # call request service. return await request.aio_call() @classmethod - async def call(cls, - model: str, - input: object, - task_group: str, - task: str = None, - function: str = None, - api_key: str = None, - workspace: str = None, - **kwargs) -> DashScopeAPIResponse: + async def call( + cls, + model: str, + input: object, # pylint: disable=redefined-builtin + task_group: str, + task: str = None, + function: str = None, + api_key: str = None, + workspace: str = None, + **kwargs, + ) -> DashScopeAPIResponse: # call request service. 
- response = await BaseAsyncAioApi.async_call(model, input, task_group, task, - function, api_key, workspace, - **kwargs) - response = await BaseAsyncAioApi.wait(response, - api_key=api_key, - workspace=workspace, - **kwargs) + response = await BaseAsyncAioApi.async_call( + model, + input, + task_group, + task, + function, + api_key, + workspace, + **kwargs, + ) + response = await BaseAsyncAioApi.wait( + response, + api_key=api_key, + workspace=workspace, + **kwargs, + ) return response - @classmethod def _get_task_id(cls, task): if isinstance(task, str): task_id = task elif isinstance(task, DashScopeAPIResponse): if task.status_code == HTTPStatus.OK: - task_id = task.output['task_id'] + task_id = task.output["task_id"] else: - raise InvalidTask('Invalid task, task create failed: %s' % - task) + raise InvalidTask( + f"Invalid task, task create failed: {task}", + ) else: - raise InvalidParameter('Task invalid!') - if task_id is None or task_id == '': - raise InvalidParameter('Task id required!') + raise InvalidParameter("Task invalid!") + if task_id is None or task_id == "": + raise InvalidParameter("Task id required!") return task_id @classmethod - async def wait(cls, - task: Union[str, DashScopeAPIResponse], - api_key: str = None, - workspace: str = None, - **kwargs) -> DashScopeAPIResponse: + async def wait( + cls, + task: Union[str, DashScopeAPIResponse], + api_key: str = None, + workspace: str = None, + **kwargs, + ) -> DashScopeAPIResponse: """Wait for async task completion and return task result. 
Args: @@ -181,25 +221,36 @@ async def wait(cls, # (server side return immediately when ready) if wait_seconds < max_wait_seconds and step % increment_steps == 0: wait_seconds = min(wait_seconds * 2, max_wait_seconds) - rsp = await cls._get(task_id, api_key, workspace=workspace, **kwargs) + rsp = await cls._get( + task_id, + api_key, + workspace=workspace, + **kwargs, + ) if rsp.status_code == HTTPStatus.OK: if rsp.output is None: return rsp - task_status = rsp.output['task_status'] + task_status = rsp.output["task_status"] if task_status in [ - TaskStatus.FAILED, TaskStatus.CANCELED, - TaskStatus.SUCCEEDED, TaskStatus.UNKNOWN + TaskStatus.FAILED, + TaskStatus.CANCELED, + TaskStatus.SUCCEEDED, + TaskStatus.UNKNOWN, ]: return rsp else: - logger.info('The task %s is %s' % (task_id, task_status)) + logger.info("The task %s is %s", task_id, task_status) await asyncio.sleep(wait_seconds) # 异步等待 elif rsp.status_code in REPEATABLE_STATUS: - logger.warn( - ('Get task: %s temporary failure, \ - status_code: %s, code: %s message: %s, will try again.' - ) % (task_id, rsp.status_code, rsp.code, rsp.message)) + logger.warning( + "Get task: %s temporary failure, " + "status_code: %s, code: %s message: %s, will try again.", + task_id, + rsp.status_code, + rsp.code, + rsp.message, + ) await asyncio.sleep(wait_seconds) # 异步等待 else: return rsp @@ -223,30 +274,39 @@ async def cancel( DashScopeAPIResponse: The cancel result. 
""" task_id = cls._get_task_id(task) - base_url = kwargs.pop('base_address', None) - url = _normalization_url(base_url, 'tasks', task_id, 'cancel') + base_url = kwargs.pop("base_address", None) + url = _normalization_url(base_url, "tasks", task_id, "cancel") kwargs = cls._handle_kwargs(api_key, workspace, **kwargs) kwargs["base_address"] = url if not api_key: api_key = get_default_api_key() - request = _build_api_request("", "", "", - "", "",api_key=api_key, - is_service=False, **kwargs) + request = _build_api_request( + "", + "", + "", + "", + "", + api_key=api_key, + is_service=False, + **kwargs, + ) return await cls._handle_request(request) @classmethod - async def list(cls, - start_time: str = None, - end_time: str = None, - model_name: str = None, - api_key_id: str = None, - region: str = None, - status: str = None, - page_no: int = 1, - page_size: int = 10, - api_key: str = None, - workspace: str = None, - **kwargs) -> DashScopeAPIResponse: + async def list( + cls, + start_time: str = None, + end_time: str = None, + model_name: str = None, + api_key_id: str = None, + region: str = None, + status: str = None, + page_no: int = 1, + page_size: int = 10, + api_key: str = None, + workspace: str = None, + **kwargs, + ) -> DashScopeAPIResponse: """List async tasks. Args: @@ -269,37 +329,46 @@ async def list(cls, Returns: DashScopeAPIResponse: The response data. 
""" - base_url = kwargs.pop('base_address', None) - url = _normalization_url(base_url, 'tasks') - params = {'page_no': page_no, 'page_size': page_size} + base_url = kwargs.pop("base_address", None) + url = _normalization_url(base_url, "tasks") + params = {"page_no": page_no, "page_size": page_size} if start_time is not None: - params['start_time'] = start_time + params["start_time"] = start_time if end_time is not None: - params['end_time'] = end_time + params["end_time"] = end_time if model_name is not None: - params['model_name'] = model_name + params["model_name"] = model_name if api_key_id is not None: - params['api_key_id'] = api_key_id + params["api_key_id"] = api_key_id if region is not None: - params['region'] = region + params["region"] = region if status is not None: - params['status'] = status + params["status"] = status kwargs = cls._handle_kwargs(api_key, workspace, **kwargs) kwargs["base_address"] = url if not api_key: api_key = get_default_api_key() - request = _build_api_request(model_name, "", "", - "", "", api_key=api_key, - is_service=False, extra_url_parameters=params, - **kwargs) + request = _build_api_request( + model_name, + "", + "", + "", + "", + api_key=api_key, + is_service=False, + extra_url_parameters=params, + **kwargs, + ) return await cls._handle_request(request) @classmethod - async def fetch(cls, - task: Union[str, DashScopeAPIResponse], - api_key: str = None, - workspace: str = None, - **kwargs) -> DashScopeAPIResponse: + async def fetch( + cls, + task: Union[str, DashScopeAPIResponse], + api_key: str = None, + workspace: str = None, + **kwargs, + ) -> DashScopeAPIResponse: """Query async task status. Args: @@ -315,27 +384,28 @@ async def fetch(cls, class BaseAioApi: - """BaseApi, internal use only. 
+ """BaseApi, internal use only.""" - """ @classmethod def _validate_params(cls, api_key, model): if api_key is None: api_key = get_default_api_key() if model is None or not model: - raise ModelRequired('Model is required!') + raise ModelRequired("Model is required!") return api_key, model @classmethod - async def call(cls, - model: str, - input: object, - task_group: str, - task: str = None, - function: str = None, - api_key: str = None, - workspace: str = None, - **kwargs) -> DashScopeAPIResponse: + async def call( + cls, + model: str, + input: object, # pylint: disable=redefined-builtin + task_group: str, + task: str = None, + function: str = None, + api_key: str = None, + workspace: str = None, + **kwargs, + ) -> DashScopeAPIResponse: """Call service and get result. Args: @@ -362,43 +432,46 @@ async def call(cls, api_key, model = BaseAioApi._validate_params(api_key, model) if workspace is not None: headers = { - 'X-DashScope-WorkSpace': workspace, - **kwargs.pop('headers', {}) + "X-DashScope-WorkSpace": workspace, + **kwargs.pop("headers", {}), } - kwargs['headers'] = headers - request = _build_api_request(model=model, - input=input, - task_group=task_group, - task=task, - function=function, - api_key=api_key, - **kwargs) + kwargs["headers"] = headers + request = _build_api_request( + model=model, + input=input, + task_group=task_group, + task=task, + function=function, + api_key=api_key, + **kwargs, + ) # call request service. return await request.aio_call() -class BaseApi(): - """BaseApi, internal use only. 
+class BaseApi: + """BaseApi, internal use only.""" - """ @classmethod def _validate_params(cls, api_key, model): if api_key is None: api_key = get_default_api_key() if model is None or not model: - raise ModelRequired('Model is required!') + raise ModelRequired("Model is required!") return api_key, model @classmethod - def call(cls, - model: str, - input: object, - task_group: str, - task: str = None, - function: str = None, - api_key: str = None, - workspace: str = None, - **kwargs) -> DashScopeAPIResponse: + def call( + cls, + model: str, + input: object, # pylint: disable=redefined-builtin + task_group: str, + task: str = None, + function: str = None, + api_key: str = None, + workspace: str = None, + **kwargs, + ) -> DashScopeAPIResponse: """Call service and get result. Args: @@ -425,24 +498,26 @@ def call(cls, api_key, model = BaseApi._validate_params(api_key, model) if workspace is not None: headers = { - 'X-DashScope-WorkSpace': workspace, - **kwargs.pop('headers', {}) + "X-DashScope-WorkSpace": workspace, + **kwargs.pop("headers", {}), } - kwargs['headers'] = headers - request = _build_api_request(model=model, - input=input, - task_group=task_group, - task=task, - function=function, - api_key=api_key, - **kwargs) + kwargs["headers"] = headers + request = _build_api_request( + model=model, + input=input, + task_group=task_group, + task=task, + function=function, + api_key=api_key, + **kwargs, + ) # call request service. 
return request.call() def _workspace_header(workspace) -> Dict: if workspace is not None: - headers = {'X-DashScope-WorkSpace': workspace} + headers = {"X-DashScope-WorkSpace": workspace} else: headers = {} return headers @@ -456,16 +531,18 @@ def _normalization_url(base_address, *args): return join_url(url, *args) -class AsyncTaskGetMixin(): +class AsyncTaskGetMixin: @classmethod - def _get(cls, - task_id: str, - api_key: str = None, - workspace: str = None, - **kwargs) -> DashScopeAPIResponse: - base_url = kwargs.pop('base_address', None) - status_url = _normalization_url(base_url, 'tasks', task_id) - custom_headers = kwargs.pop('headers', None) + def _get( + cls, + task_id: str, + api_key: str = None, + workspace: str = None, + **kwargs, + ) -> DashScopeAPIResponse: + base_url = kwargs.pop("base_address", None) + status_url = _normalization_url(base_url, "tasks", task_id) + custom_headers = kwargs.pop("headers", None) headers = { **_workspace_header(workspace), **default_headers(api_key), @@ -476,41 +553,47 @@ def _get(cls, **headers, } with requests.Session() as session: - logger.debug('Starting request: %s' % status_url) - response = session.get(status_url, - headers=headers, - timeout=DEFAULT_REQUEST_TIMEOUT_SECONDS) - logger.debug('Starting processing response: %s' % status_url) + logger.debug("Starting request: %s", status_url) + response = session.get( + status_url, + headers=headers, + timeout=DEFAULT_REQUEST_TIMEOUT_SECONDS, + ) + logger.debug("Starting processing response: %s", status_url) return _handle_http_response(response) class BaseAsyncApi(AsyncTaskGetMixin): - """BaseAsyncApi,for async task, internal use only. 
+ """BaseAsyncApi,for async task, internal use only.""" - """ @classmethod def _validate_params(cls, api_key, model): if api_key is None: api_key = get_default_api_key() if model is None or not model: - raise ModelRequired('Model is required!') + raise ModelRequired("Model is required!") return api_key, model @classmethod - def call(cls, - *args, - api_key: str = None, - workspace: str = None, - **kwargs) -> DashScopeAPIResponse: - """Call service and get result. - """ - task_response = cls.async_call(*args, - api_key=api_key, - workspace=workspace, - **kwargs) - response = cls.wait(task_response, - api_key=api_key, - workspace=workspace) + def call( + cls, + *args, + api_key: str = None, + workspace: str = None, + **kwargs, + ) -> DashScopeAPIResponse: + """Call service and get result.""" + task_response = cls.async_call( # type: ignore[misc] + *args, + api_key=api_key, + workspace=workspace, + **kwargs, + ) + response = cls.wait( + task_response, + api_key=api_key, + workspace=workspace, + ) return response @classmethod @@ -519,14 +602,15 @@ def _get_task_id(cls, task): task_id = task elif isinstance(task, DashScopeAPIResponse): if task.status_code == HTTPStatus.OK: - task_id = task.output['task_id'] + task_id = task.output["task_id"] else: - raise InvalidTask('Invalid task, task create failed: %s' % - task) + raise InvalidTask( + f"Invalid task, task create failed: {task}", + ) else: - raise InvalidParameter('Task invalid!') - if task_id is None or task_id == '': - raise InvalidParameter('Task id required!') + raise InvalidParameter("Task invalid!") + if task_id is None or task_id == "": + raise InvalidParameter("Task id required!") return task_id @classmethod @@ -548,29 +632,33 @@ def cancel( DashScopeAPIResponse: The cancel result. 
""" task_id = cls._get_task_id(task) - base_url = kwargs.pop('base_address', None) - url = _normalization_url(base_url, 'tasks', task_id, 'cancel') + base_url = kwargs.pop("base_address", None) + url = _normalization_url(base_url, "tasks", task_id, "cancel") with requests.Session() as session: - response = session.post(url, - headers={ - **_workspace_header(workspace), - **default_headers(api_key), - }) + response = session.post( + url, + headers={ + **_workspace_header(workspace), + **default_headers(api_key), + }, + ) return _handle_http_response(response) @classmethod - def list(cls, - start_time: str = None, - end_time: str = None, - model_name: str = None, - api_key_id: str = None, - region: str = None, - status: str = None, - page_no: int = 1, - page_size: int = 10, - api_key: str = None, - workspace: str = None, - **kwargs) -> DashScopeAPIResponse: + def list( + cls, + start_time: str = None, + end_time: str = None, + model_name: str = None, + api_key_id: str = None, + region: str = None, + status: str = None, + page_no: int = 1, + page_size: int = 10, + api_key: str = None, + workspace: str = None, + **kwargs, + ) -> DashScopeAPIResponse: """List async tasks. Args: @@ -593,41 +681,45 @@ def list(cls, Returns: DashScopeAPIResponse: The response data. 
""" - base_url = kwargs.pop('base_address', None) - url = _normalization_url(base_url, 'tasks') - params = {'page_no': page_no, 'page_size': page_size} + base_url = kwargs.pop("base_address", None) + url = _normalization_url(base_url, "tasks") + params = {"page_no": page_no, "page_size": page_size} if start_time is not None: - params['start_time'] = start_time + params["start_time"] = start_time if end_time is not None: - params['end_time'] = end_time + params["end_time"] = end_time if model_name is not None: - params['model_name'] = model_name + params["model_name"] = model_name if api_key_id is not None: - params['api_key_id'] = api_key_id + params["api_key_id"] = api_key_id if region is not None: - params['region'] = region + params["region"] = region if status is not None: - params['status'] = status + params["status"] = status with requests.Session() as session: - response = session.get(url, - params=params, - headers={ - **_workspace_header(workspace), - **default_headers(api_key), - }) + response = session.get( + url, + params=params, + headers={ + **_workspace_header(workspace), + **default_headers(api_key), + }, + ) if response.status_code == HTTPStatus.OK: json_content = response.json() - request_id = '' - if 'request_id' in json_content: - request_id = json_content['request_id'] - json_content.pop('request_id') - return DashScopeAPIResponse(request_id=request_id, - status_code=response.status_code, - code=None, - output=json_content, - usage=None, - message='') + request_id = "" + if "request_id" in json_content: + request_id = json_content["request_id"] + json_content.pop("request_id") + return DashScopeAPIResponse( + request_id=request_id, + status_code=response.status_code, + code=None, # type: ignore[arg-type] + output=json_content, + usage=None, + message="", + ) else: return _handle_http_failed_response(response) @@ -653,11 +745,13 @@ def fetch( return cls._get(task_id, api_key, workspace, **kwargs) @classmethod - def wait(cls, - task: Union[str, 
DashScopeAPIResponse], - api_key: str = None, - workspace: str = None, - **kwargs) -> DashScopeAPIResponse: + def wait( + cls, + task: Union[str, DashScopeAPIResponse], + api_key: str = None, + workspace: str = None, + **kwargs, + ) -> DashScopeAPIResponse: """Wait for async task completion and return task result. Args: @@ -688,34 +782,42 @@ def wait(cls, if rsp.output is None: return rsp - task_status = rsp.output['task_status'] + task_status = rsp.output["task_status"] if task_status in [ - TaskStatus.FAILED, TaskStatus.CANCELED, - TaskStatus.SUCCEEDED, TaskStatus.UNKNOWN + TaskStatus.FAILED, + TaskStatus.CANCELED, + TaskStatus.SUCCEEDED, + TaskStatus.UNKNOWN, ]: return rsp else: - logger.info('The task %s is %s' % (task_id, task_status)) + logger.info("The task %s is %s", task_id, task_status) time.sleep(wait_seconds) elif rsp.status_code in REPEATABLE_STATUS: - logger.warn( - ('Get task: %s temporary failure, \ - status_code: %s, code: %s message: %s, will try again.' - ) % (task_id, rsp.status_code, rsp.code, rsp.message)) + logger.warning( + "Get task: %s temporary failure, " + "status_code: %s, code: %s message: %s, will try again.", + task_id, + rsp.status_code, + rsp.code, + rsp.message, + ) time.sleep(wait_seconds) else: return rsp @classmethod - def async_call(cls, - model: str, - input: object, - task_group: str, - task: str = None, - function: str = None, - api_key: str = None, - workspace: str = None, - **kwargs) -> DashScopeAPIResponse: + def async_call( + cls, + model: str, + input: object, # pylint: disable=redefined-builtin + task_group: str, + task: str = None, + function: str = None, + api_key: str = None, + workspace: str = None, + **kwargs, + ) -> DashScopeAPIResponse: """Call async service return async task information. Args: @@ -733,47 +835,63 @@ def async_call(cls, which contains the task id, you can use the task id to query the task status. """ - is_stream = kwargs.pop('stream', None) # async api not support stream. 
+ is_stream = kwargs.pop("stream", None) # async api not support stream. if is_stream: - logger.warn('async_call do not support stream argument') - api_key, model = BaseApi._validate_params(api_key, model) + logger.warning("async_call do not support stream argument") + # Access BaseApi's validation method for consistency + ( + api_key, + model, + ) = BaseApi._validate_params( # pylint: disable=protected-access + api_key, + model, + ) if workspace is not None: headers = { - 'X-DashScope-WorkSpace': workspace, - **kwargs.pop('headers', {}) + "X-DashScope-WorkSpace": workspace, + **kwargs.pop("headers", {}), } - kwargs['headers'] = headers - request = _build_api_request(model=model, - input=input, - task_group=task_group, - task=task, - function=function, - api_key=api_key, - async_request=True, - query=False, - **kwargs) + kwargs["headers"] = headers + request = _build_api_request( + model=model, + input=input, + task_group=task_group, + task=task, + function=function, + api_key=api_key, + async_request=True, + query=False, + **kwargs, + ) return request.call() -def _get(url, - params={}, - api_key=None, - flattened_output=False, - workspace: str = None, - **kwargs) -> Union[DashScopeAPIResponse, Dict]: - timeout = kwargs.pop(REQUEST_TIMEOUT_KEYWORD, - DEFAULT_REQUEST_TIMEOUT_SECONDS) +# pylint: disable=dangerous-default-value +def _get( + url, + params={}, + api_key=None, + flattened_output=False, + workspace: str = None, + **kwargs, +) -> Union[DashScopeAPIResponse, Dict]: + timeout = kwargs.pop( + REQUEST_TIMEOUT_KEYWORD, + DEFAULT_REQUEST_TIMEOUT_SECONDS, + ) with requests.Session() as session: - logger.debug('Starting request: %s' % url) - response = session.get(url, - headers={ - **_workspace_header(workspace), - **default_headers(api_key), - **kwargs.pop('headers', {}) - }, - params=params, - timeout=timeout) - logger.debug('Starting processing response: %s' % url) + logger.debug("Starting request: %s", url) + response = session.get( + url, + headers={ + 
**_workspace_header(workspace), + **default_headers(api_key), + **kwargs.pop("headers", {}), + }, + params=params, + timeout=timeout, + ) + logger.debug("Starting processing response: %s", url) return _handle_http_response(response, flattened_output) @@ -789,60 +907,66 @@ def _get_url(custom_base_url, default_path, path): return url -class ListObjectMixin(): +class ListObjectMixin: @classmethod - def list(cls, - limit: int = None, - order: str = None, - after: str = None, - before: str = None, - path: str = None, - workspace: str = None, - api_key: str = None, - **kwargs) -> Any: + def list( + cls, + limit: int = None, + order: str = None, + after: str = None, + before: str = None, + path: str = None, + workspace: str = None, + api_key: str = None, + **kwargs, + ) -> Any: """List object Args: limit (int, optional): How many object to list. Defaults to None. order (str, optional): The order of result. Defaults to None. - after (str, optional): The id of the object begin. Defaults to None. + after (str, optional): The id of the object begin. Defaults to None. # noqa: E501 before (str, optional): The if of the object end. Defaults to None. path (str, optional): The request path. Defaults to None. - workspace (str, optional): The DashScope workspace id. Defaults to None. + workspace (str, optional): The DashScope workspace id. Defaults to None. # noqa: E501 api_key (str, optional): The DashScope api_key. Defaults to None. Returns: Any: The object list. 
""" - custom_base_url = kwargs.pop('base_address', None) + custom_base_url = kwargs.pop("base_address", None) url = _get_url(custom_base_url, cls.SUB_PATH.lower(), path) params = {} if limit is not None: if limit < 0: - raise InvalidParameter('limit should >= 0') - params['limit'] = limit + raise InvalidParameter("limit should >= 0") + params["limit"] = limit if order is not None: - params['order'] = order + params["order"] = order if after is not None: - params['after'] = after + params["after"] = after if before is not None: - params['before'] = before - return _get(url, - params=params, - api_key=api_key, - workspace=workspace, - **kwargs) + params["before"] = before + return _get( + url, + params=params, + api_key=api_key, + workspace=workspace, + **kwargs, + ) -class ListMixin(): +class ListMixin: @classmethod - def list(cls, - page_no=1, - page_size=10, - api_key: str = None, - path: str = None, - workspace: str = None, - **kwargs) -> Union[DashScopeAPIResponse, Dict]: + def list( + cls, + page_no=1, + page_size=10, + api_key: str = None, + path: str = None, + workspace: str = None, + **kwargs, + ) -> Union[DashScopeAPIResponse, Dict]: """list objects Args: @@ -855,26 +979,30 @@ def list(cls, Returns: DashScopeAPIResponse: The object list in output. 
""" - custom_base_url = kwargs.pop('base_address', None) + custom_base_url = kwargs.pop("base_address", None) url = _get_url(custom_base_url, cls.SUB_PATH.lower(), path) - params = {'page_no': page_no, 'page_size': page_size} - return _get(url, - params=params, - api_key=api_key, - workspace=workspace, - **kwargs) + params = {"page_no": page_no, "page_size": page_size} + return _get( + url, + params=params, + api_key=api_key, + workspace=workspace, + **kwargs, + ) -class LogMixin(): +class LogMixin: @classmethod - def logs(cls, - job_id: str, - offset=1, - line=1000, - api_key: str = None, - path: str = None, - workspace: str = None, - **kwargs) -> Union[DashScopeAPIResponse, Dict]: + def logs( # pylint: disable=unused-argument + cls, + job_id: str, + offset=1, + line=1000, + api_key: str = None, + path: str = None, + workspace: str = None, + **kwargs, + ) -> Union[DashScopeAPIResponse, Dict]: """Get log of the job. Args: @@ -886,29 +1014,38 @@ def logs(cls, Returns: DashScopeAPIResponse: The response """ - custom_base_url = kwargs.pop('base_address', None) + custom_base_url = kwargs.pop("base_address", None) if not custom_base_url: - url = join_url(dashscope.base_http_api_url, cls.SUB_PATH.lower(), - job_id, 'logs') + url = join_url( + dashscope.base_http_api_url, + cls.SUB_PATH.lower(), + job_id, + "logs", + ) else: url = custom_base_url - params = {'offset': offset, 'line': line} - return _get(url, - params=params, - api_key=api_key, - workspace=workspace, - **kwargs) + params = {"offset": offset, "line": line} + return _get( + url, + params=params, + api_key=api_key, + workspace=workspace, + **kwargs, + ) -class GetMixin(): +class GetMixin: @classmethod - def get(cls, - target, - api_key: str = None, - params: dict = {}, - path: str = None, - workspace: str = None, - **kwargs) -> Union[DashScopeAPIResponse, Dict]: + # pylint: disable=dangerous-default-value + def get( + cls, + target, + api_key: str = None, + params: dict = {}, + path: str = None, + workspace: 
str = None, + **kwargs, + ) -> Union[DashScopeAPIResponse, Dict]: """Get object information. Args: @@ -919,7 +1056,7 @@ def get(cls, Returns: DashScopeAPIResponse: The object information in output. """ - custom_base_url = kwargs.pop('base_address', None) + custom_base_url = kwargs.pop("base_address", None) if custom_base_url: base_url = custom_base_url else: @@ -929,23 +1066,27 @@ def get(cls, url = join_url(base_url, path) else: url = join_url(base_url, cls.SUB_PATH.lower(), target) - flattened_output = kwargs.pop('flattened_output', False) - return _get(url, - api_key=api_key, - params=params, - flattened_output=flattened_output, - workspace=workspace, - **kwargs) - - -class GetStatusMixin(): + flattened_output = kwargs.pop("flattened_output", False) + return _get( + url, + api_key=api_key, + params=params, + flattened_output=flattened_output, + workspace=workspace, + **kwargs, + ) + + +class GetStatusMixin: @classmethod - def get(cls, - target, - api_key: str = None, - path: str = None, - workspace: str = None, - **kwargs) -> Union[DashScopeAPIResponse, Dict]: + def get( + cls, + target, + api_key: str = None, + path: str = None, + workspace: str = None, + **kwargs, + ) -> Union[DashScopeAPIResponse, Dict]: """Get object information. Args: @@ -956,7 +1097,7 @@ def get(cls, Returns: DashScopeAPIResponse: The object information in output. 
""" - custom_base_url = kwargs.pop('base_address', None) + custom_base_url = kwargs.pop("base_address", None) if custom_base_url: base_url = custom_base_url else: @@ -965,23 +1106,27 @@ def get(cls, url = join_url(base_url, path) else: url = join_url(base_url, cls.SUB_PATH.lower(), target) - flattened_output = kwargs.pop('flattened_output', False) - return _get(url, - api_key=api_key, - flattened_output=flattened_output, - workspace=workspace, - **kwargs) + flattened_output = kwargs.pop("flattened_output", False) + return _get( + url, + api_key=api_key, + flattened_output=flattened_output, + workspace=workspace, + **kwargs, + ) -class DeleteMixin(): +class DeleteMixin: @classmethod - def delete(cls, - target: str, - api_key: str = None, - path: str = None, - workspace: str = None, - flattened_output=False, - **kwargs) -> Union[DashScopeAPIResponse, Dict]: + def delete( + cls, + target: str, + api_key: str = None, + path: str = None, + workspace: str = None, + flattened_output=False, + **kwargs, + ) -> Union[DashScopeAPIResponse, Dict]: """Delete object. Args: @@ -992,7 +1137,7 @@ def delete(cls, Returns: DashScopeAPIResponse: The delete result. 
""" - custom_base_url = kwargs.pop('base_address', None) + custom_base_url = kwargs.pop("base_address", None) if custom_base_url: base_url = custom_base_url else: @@ -1001,30 +1146,36 @@ def delete(cls, url = join_url(base_url, path) else: url = join_url(base_url, cls.SUB_PATH.lower(), target) - timeout = kwargs.pop(REQUEST_TIMEOUT_KEYWORD, - DEFAULT_REQUEST_TIMEOUT_SECONDS) + timeout = kwargs.pop( + REQUEST_TIMEOUT_KEYWORD, + DEFAULT_REQUEST_TIMEOUT_SECONDS, + ) with requests.Session() as session: - logger.debug('Starting request: %s' % url) - response = session.delete(url, - headers={ - **_workspace_header(workspace), - **default_headers(api_key), - **kwargs.pop('headers', {}) - }, - timeout=timeout) - logger.debug('Starting processing response: %s' % url) + logger.debug("Starting request: %s", url) + response = session.delete( + url, + headers={ + **_workspace_header(workspace), + **default_headers(api_key), + **kwargs.pop("headers", {}), + }, + timeout=timeout, + ) + logger.debug("Starting processing response: %s", url) return _handle_http_response(response, flattened_output) -class CreateMixin(): +class CreateMixin: @classmethod - def call(cls, - data: object, - api_key: str = None, - path: str = None, - stream: bool = False, - workspace: str = None, - **kwargs) -> Union[DashScopeAPIResponse, Dict]: + def call( + cls, + data: object, + api_key: str = None, + path: str = None, + stream: bool = False, + workspace: str = None, + **kwargs, + ) -> Union[DashScopeAPIResponse, Dict]: """Create a object Args: @@ -1035,26 +1186,33 @@ def call(cls, Returns: DashScopeAPIResponse: The created object in output. 
""" - url = _get_url(kwargs.pop('base_address', None), cls.SUB_PATH.lower(), - path) - timeout = kwargs.pop(REQUEST_TIMEOUT_KEYWORD, - DEFAULT_REQUEST_TIMEOUT_SECONDS) - flattened_output = kwargs.pop('flattened_output', False) + url = _get_url( + kwargs.pop("base_address", None), + cls.SUB_PATH.lower(), + path, + ) + timeout = kwargs.pop( + REQUEST_TIMEOUT_KEYWORD, + DEFAULT_REQUEST_TIMEOUT_SECONDS, + ) + flattened_output = kwargs.pop("flattened_output", False) with requests.Session() as session: - logger.debug('Starting request: %s' % url) - response = session.post(url, - json=data, - stream=stream, - headers={ - **_workspace_header(workspace), - **default_headers(api_key), - **kwargs.pop('headers', {}) - }, - timeout=timeout) - logger.debug('Starting processing response: %s' % url) + logger.debug("Starting request: %s", url) + response = session.post( + url, + json=data, + stream=stream, + headers={ + **_workspace_header(workspace), + **default_headers(api_key), + **kwargs.pop("headers", {}), + }, + timeout=timeout, + ) + logger.debug("Starting processing response: %s", url) response = _handle_http_stream_response(response, flattened_output) if stream: - return (item for item in response) + return (item for item in response) # type: ignore else: _, output = next(response) try: @@ -1064,16 +1222,18 @@ def call(cls, return output -class UpdateMixin(): +class UpdateMixin: @classmethod - def update(cls, - target: str, - json: object, - api_key: str = None, - path: str = None, - workspace: str = None, - method: str = 'patch', - **kwargs) -> Union[DashScopeAPIResponse, Dict]: + def update( + cls, + target: str, + json: object, + api_key: str = None, + path: str = None, + workspace: str = None, + method: str = "patch", + **kwargs, + ) -> Union[DashScopeAPIResponse, Dict]: """Async update a object Args: @@ -1085,7 +1245,7 @@ def update(cls, Returns: DashScopeAPIResponse: The updated object information in output. 
""" - custom_base_url = kwargs.pop('base_address', None) + custom_base_url = kwargs.pop("base_address", None) if custom_base_url: base_url = custom_base_url else: @@ -1094,42 +1254,50 @@ def update(cls, url = join_url(base_url, path) else: url = join_url(base_url, cls.SUB_PATH.lower(), target) - timeout = kwargs.pop(REQUEST_TIMEOUT_KEYWORD, - DEFAULT_REQUEST_TIMEOUT_SECONDS) - flattened_output = kwargs.pop('flattened_output', False) + timeout = kwargs.pop( + REQUEST_TIMEOUT_KEYWORD, + DEFAULT_REQUEST_TIMEOUT_SECONDS, + ) + flattened_output = kwargs.pop("flattened_output", False) with requests.Session() as session: - logger.debug('Starting request: %s' % url) - if method == 'post': - response = session.post(url, - json=json, - headers={ - **_workspace_header(workspace), - **default_headers(api_key), - **kwargs.pop('headers', {}) - }, - timeout=timeout) + logger.debug("Starting request: %s", url) + if method == "post": + response = session.post( + url, + json=json, + headers={ + **_workspace_header(workspace), + **default_headers(api_key), + **kwargs.pop("headers", {}), + }, + timeout=timeout, + ) else: - response = session.patch(url, - json=json, - headers={ - **_workspace_header(workspace), - **default_headers(api_key), - **kwargs.pop('headers', {}) - }, - timeout=timeout) - logger.debug('Starting processing response: %s' % url) + response = session.patch( + url, + json=json, + headers={ + **_workspace_header(workspace), + **default_headers(api_key), + **kwargs.pop("headers", {}), + }, + timeout=timeout, + ) + logger.debug("Starting processing response: %s", url) return _handle_http_response(response, flattened_output) -class PutMixin(): +class PutMixin: @classmethod - def put(cls, - target: str, - json: object, - path: str = None, - api_key: str = None, - workspace: str = None, - **kwargs) -> Union[DashScopeAPIResponse, Dict]: + def put( + cls, + target: str, + json: object, + path: str = None, + api_key: str = None, + workspace: str = None, + **kwargs, + ) -> 
Union[DashScopeAPIResponse, Dict]: """Async update a object Args: @@ -1141,7 +1309,7 @@ def put(cls, Returns: DashScopeAPIResponse: The updated object information in output. """ - custom_base_url = kwargs.pop('base_address', None) + custom_base_url = kwargs.pop("base_address", None) if custom_base_url: base_url = custom_base_url else: @@ -1150,31 +1318,37 @@ def put(cls, url = join_url(base_url, cls.SUB_PATH.lower(), target) else: url = join_url(base_url, path) - timeout = kwargs.pop(REQUEST_TIMEOUT_KEYWORD, - DEFAULT_REQUEST_TIMEOUT_SECONDS) + timeout = kwargs.pop( + REQUEST_TIMEOUT_KEYWORD, + DEFAULT_REQUEST_TIMEOUT_SECONDS, + ) with requests.Session() as session: - logger.debug('Starting request: %s' % url) - response = session.put(url, - json=json, - headers={ - **_workspace_header(workspace), - **default_headers(api_key), - **kwargs.pop('headers', {}) - }, - timeout=timeout) - logger.debug('Starting processing response: %s' % url) + logger.debug("Starting request: %s", url) + response = session.put( + url, + json=json, + headers={ + **_workspace_header(workspace), + **default_headers(api_key), + **kwargs.pop("headers", {}), + }, + timeout=timeout, + ) + logger.debug("Starting processing response: %s", url) return _handle_http_response(response) -class FileUploadMixin(): +class FileUploadMixin: @classmethod - def upload(cls, - files: list, - descriptions: List[str] = None, - params: dict = None, - api_key: str = None, - workspace: str = None, - **kwargs) -> Union[DashScopeAPIResponse, Dict]: + def upload( # pylint: disable=unused-argument + cls, + files: list, + descriptions: List[str] = None, + params: dict = None, + api_key: str = None, + workspace: str = None, + **kwargs, + ) -> Union[DashScopeAPIResponse, Dict]: """Upload files Args: @@ -1187,7 +1361,7 @@ def upload(cls, Returns: DashScopeAPIResponse: The uploaded file information in the output. 
""" - custom_base_url = kwargs.pop('base_address', None) + custom_base_url = kwargs.pop("base_address", None) if custom_base_url: base_url = custom_base_url else: @@ -1195,32 +1369,38 @@ def upload(cls, url = join_url(base_url, cls.SUB_PATH.lower()) js = None if descriptions: - js = {'descriptions': descriptions} - timeout = kwargs.pop(REQUEST_TIMEOUT_KEYWORD, - DEFAULT_REQUEST_TIMEOUT_SECONDS) + js = {"descriptions": descriptions} + timeout = kwargs.pop( + REQUEST_TIMEOUT_KEYWORD, + DEFAULT_REQUEST_TIMEOUT_SECONDS, + ) with requests.Session() as session: - logger.debug('Starting request: %s' % url) - response = session.post(url, - data=js, - headers={ - **_workspace_header(workspace), - **default_headers(api_key), - **kwargs.pop('headers', {}) - }, - files=files, - timeout=timeout) - logger.debug('Starting processing response: %s' % url) + logger.debug("Starting request: %s", url) + response = session.post( + url, + data=js, + headers={ + **_workspace_header(workspace), + **default_headers(api_key), + **kwargs.pop("headers", {}), + }, + files=files, + timeout=timeout, + ) + logger.debug("Starting processing response: %s", url) return _handle_http_response(response) -class CancelMixin(): +class CancelMixin: @classmethod - def cancel(cls, - target: str, - path: str = None, - api_key: str = None, - workspace: str = None, - **kwargs) -> Union[DashScopeAPIResponse, Dict]: + def cancel( + cls, + target: str, + path: str = None, + api_key: str = None, + workspace: str = None, + **kwargs, + ) -> Union[DashScopeAPIResponse, Dict]: """Cancel a job. Args: @@ -1231,32 +1411,36 @@ def cancel(cls, Returns: DashScopeAPIResponse: The cancel result. 
""" - custom_base_url = kwargs.pop('base_address', None) + custom_base_url = kwargs.pop("base_address", None) if custom_base_url: base_url = custom_base_url else: base_url = dashscope.base_http_api_url if not path: - url = join_url(base_url, cls.SUB_PATH.lower(), target, 'cancel') + url = join_url(base_url, cls.SUB_PATH.lower(), target, "cancel") else: url = join_url(base_url, path) - timeout = kwargs.pop(REQUEST_TIMEOUT_KEYWORD, - DEFAULT_REQUEST_TIMEOUT_SECONDS) - flattened_output = kwargs.pop('flattened_output', False) + timeout = kwargs.pop( + REQUEST_TIMEOUT_KEYWORD, + DEFAULT_REQUEST_TIMEOUT_SECONDS, + ) + flattened_output = kwargs.pop("flattened_output", False) with requests.Session() as session: - logger.debug('Starting request: %s' % url) - response = session.post(url, - headers={ - **_workspace_header(workspace), - **default_headers(api_key), - **kwargs.pop('headers', {}) - }, - timeout=timeout) - logger.debug('Starting processing response: %s' % url) + logger.debug("Starting request: %s", url) + response = session.post( + url, + headers={ + **_workspace_header(workspace), + **default_headers(api_key), + **kwargs.pop("headers", {}), + }, + timeout=timeout, + ) + logger.debug("Starting processing response: %s", url) return _handle_http_response(response, flattened_output) -class StreamEventMixin(): +class StreamEventMixin: @classmethod def _handle_stream(cls, response: requests.Response): # TODO define done message. 
@@ -1264,15 +1448,15 @@ def _handle_stream(cls, response: requests.Response): status_code = HTTPStatus.INTERNAL_SERVER_ERROR for line in response.iter_lines(): if line: - line = line.decode('utf8') - line = line.rstrip('\n').rstrip('\r') - if line.startswith('event:error'): + line = line.decode("utf8") + line = line.rstrip("\n").rstrip("\r") + if line.startswith("event:error"): is_error = True - elif line.startswith('status:'): - status_code = line[len('status:'):] + elif line.startswith("status:"): + status_code = line[len("status:") :] status_code = int(status_code.strip()) - elif line.startswith('data:'): - line = line[len('data:'):] + elif line.startswith("data:"): + line = line[len("data:") :] yield (is_error, status_code, line) if is_error: break @@ -1281,40 +1465,53 @@ def _handle_stream(cls, response: requests.Response): @classmethod def _handle_response(cls, response: requests.Response): - request_id = '' - if (response.status_code == HTTPStatus.OK - and SSE_CONTENT_TYPE in response.headers.get( - 'content-type', '')): + request_id = "" + if ( + response.status_code == HTTPStatus.OK + and SSE_CONTENT_TYPE + in response.headers.get( + "content-type", + "", + ) + ): for is_error, status_code, data in cls._handle_stream(response): if is_error: - yield DashScopeAPIResponse(request_id=request_id, - status_code=status_code, - output=None, - code='', - message='') # noqa E501 + yield DashScopeAPIResponse( + request_id=request_id, + status_code=status_code, + output=None, + code="", + message="", + ) # noqa E501 else: - yield DashScopeAPIResponse(request_id=request_id, - status_code=HTTPStatus.OK, - output=data, - usage=None) + yield DashScopeAPIResponse( + request_id=request_id, + status_code=HTTPStatus.OK, + output=data, + usage=None, + ) elif response.status_code == HTTPStatus.OK: json_content = response.json() - request_id = '' - if 'request_id' in json_content: - request_id = json_content['request_id'] - yield DashScopeAPIResponse(request_id=request_id, - 
status_code=HTTPStatus.OK, - output=json_content, - usage=None) + request_id = "" + if "request_id" in json_content: + request_id = json_content["request_id"] + yield DashScopeAPIResponse( + request_id=request_id, + status_code=HTTPStatus.OK, + output=json_content, + usage=None, + ) else: yield _handle_http_failed_response(response) @classmethod - def stream_events(cls, - target, - api_key: str = None, - workspace: str = None, - **kwargs) -> Iterator[DashScopeAPIResponse]: + def stream_events( + cls, + target, + api_key: str = None, + workspace: str = None, + **kwargs, + ) -> Iterator[DashScopeAPIResponse]: """Get job log. Args: @@ -1325,24 +1522,28 @@ def stream_events(cls, Returns: DashScopeAPIResponse: The target outputs. """ - custom_base_url = kwargs.pop('base_address', None) + custom_base_url = kwargs.pop("base_address", None) if custom_base_url: base_url = custom_base_url else: base_url = dashscope.base_http_api_url - url = join_url(base_url, cls.SUB_PATH.lower(), target, 'stream') - timeout = kwargs.pop(REQUEST_TIMEOUT_KEYWORD, - DEFAULT_REQUEST_TIMEOUT_SECONDS) + url = join_url(base_url, cls.SUB_PATH.lower(), target, "stream") + timeout = kwargs.pop( + REQUEST_TIMEOUT_KEYWORD, + DEFAULT_REQUEST_TIMEOUT_SECONDS, + ) with requests.Session() as session: - logger.debug('Starting request: %s' % url) - response = session.get(url, - headers={ - **_workspace_header(workspace), - **default_headers(api_key), - **kwargs.pop('headers', {}) - }, - stream=True, - timeout=timeout) - logger.debug('Starting processing response: %s' % url) + logger.debug("Starting request: %s", url) + response = session.get( + url, + headers={ + **_workspace_header(workspace), + **default_headers(api_key), + **kwargs.pop("headers", {}), + }, + stream=True, + timeout=timeout, + ) + logger.debug("Starting processing response: %s", url) for rsp in cls._handle_response(response): yield rsp diff --git a/dashscope/common/api_key.py b/dashscope/common/api_key.py index e47c930..d083e84 100644 --- 
a/dashscope/common/api_key.py +++ b/dashscope/common/api_key.py @@ -1,11 +1,14 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. import os from typing import Optional import dashscope -from dashscope.common.constants import (DEFAULT_DASHSCOPE_API_KEY_FILE_PATH, - DEFAULT_DASHSCOPE_CACHE_PATH) +from dashscope.common.constants import ( + DEFAULT_DASHSCOPE_API_KEY_FILE_PATH, + DEFAULT_DASHSCOPE_CACHE_PATH, +) from dashscope.common.error import AuthenticationError @@ -15,31 +18,38 @@ def get_default_api_key(): return dashscope.api_key elif dashscope.api_key_file_path: # user set environment variable DASHSCOPE_API_KEY_FILE_PATH - with open(dashscope.api_key_file_path, 'rt', - encoding='utf-8') as f: # open with text mode. + with open( + dashscope.api_key_file_path, + "rt", + encoding="utf-8", + ) as f: # open with text mode. return f.read().strip() else: # Find the api key from default key file. if os.path.exists(DEFAULT_DASHSCOPE_API_KEY_FILE_PATH): - with open(DEFAULT_DASHSCOPE_API_KEY_FILE_PATH, - 'rt', - encoding='utf-8') as f: + with open( + DEFAULT_DASHSCOPE_API_KEY_FILE_PATH, + "rt", + encoding="utf-8", + ) as f: return f.read().strip() raise AuthenticationError( - 'No api key provided. You can set by dashscope.api_key = your_api_key in code, ' # noqa: E501 - 'or you can set it via environment variable DASHSCOPE_API_KEY= your_api_key. ' # noqa: E501 - 'You can store your api key to a file, and use dashscope.api_key_file_path=api_key_file_path in code, ' # noqa: E501 - 'or you can set api key file path via environment variable DASHSCOPE_API_KEY_FILE_PATH, ' # noqa: E501 - 'You can call save_api_key to api_key_file_path or default path(~/.dashscope/api_key).' # noqa: E501 + "No api key provided. You can set by dashscope.api_key = your_api_key in code, " # noqa: E501 # pylint: disable=line-too-long + "or you can set it via environment variable DASHSCOPE_API_KEY= your_api_key. 
" # noqa: E501 + "You can store your api key to a file, and use dashscope.api_key_file_path=api_key_file_path in code, " # noqa: E501 # pylint: disable=line-too-long + "or you can set api key file path via environment variable DASHSCOPE_API_KEY_FILE_PATH, " # noqa: E501 # pylint: disable=line-too-long + "You can call save_api_key to api_key_file_path or default path(~/.dashscope/api_key).", # noqa: E501 # pylint: disable=line-too-long ) def save_api_key(api_key: str, api_key_file_path: Optional[str] = None): if api_key_file_path is None: os.makedirs(DEFAULT_DASHSCOPE_CACHE_PATH, exist_ok=True) - with open(DEFAULT_DASHSCOPE_API_KEY_FILE_PATH, 'w+') as f: + # pylint: disable=unspecified-encoding + with open(DEFAULT_DASHSCOPE_API_KEY_FILE_PATH, "w+") as f: f.write(api_key) else: os.makedirs(os.path.dirname(api_key_file_path), exist_ok=True) - with open(api_key_file_path, 'w+') as f: + # pylint: disable=unspecified-encoding + with open(api_key_file_path, "w+") as f: f.write(api_key) diff --git a/dashscope/common/base_type.py b/dashscope/common/base_type.py index 8a67e65..6274c5d 100644 --- a/dashscope/common/base_type.py +++ b/dashscope/common/base_type.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
import dataclasses @@ -9,20 +10,20 @@ def get_object_type(name: str): dashscope_objects = { - 'assistant': dashscope.Assistant, - 'assistant.deleted': dashscope.DeleteResponse, - 'thread.message': dashscope.ThreadMessage, - 'thread.run': dashscope.Run, - 'thread.run.step': dashscope.RunStep, - 'thread.message.file': dashscope.MessageFile, - 'assistant.file': dashscope.AssistantFile, - 'thread': dashscope.Thread, + "assistant": dashscope.Assistant, + "assistant.deleted": dashscope.DeleteResponse, + "thread.message": dashscope.ThreadMessage, + "thread.run": dashscope.Run, + "thread.run.step": dashscope.RunStep, + "thread.message.file": dashscope.MessageFile, + "assistant.file": dashscope.AssistantFile, + "thread": dashscope.Thread, } return dashscope_objects.get(name, None) @dataclass(init=False) -class BaseObjectMixin(object): +class BaseObjectMixin(object): # pylint: disable=useless-object-inheritance __slots__ = () def __init__(self, **kwargs): @@ -35,7 +36,7 @@ def __init__(self, **kwargs): continue if isinstance(v, dict): - object_name = v.get('object', None) + object_name = v.get("object", None) if object_name: object_type = get_object_type(object_name) if object_type: @@ -62,7 +63,7 @@ def _init_list_element_recursive(self, field, items: list) -> List[Any]: continue if isinstance(item, dict): - object_name = item.get('object', None) + object_name = item.get("object", None) if object_name: object_type = get_object_type(object_name) if object_type: @@ -72,7 +73,9 @@ def _init_list_element_recursive(self, field, items: list) -> List[Any]: else: obj_list.append(item) elif isinstance(item, list): - obj_list.append(self._init_list_element_recursive(item)) + # Recursively initialize nested list elements + # pylint: disable=no-value-for-parameter + obj_list.append(self._init_list_element_recursive(item)) # type: ignore[call-arg] # pylint: disable=line-too-long # noqa: E501 else: obj_list.append(item) return obj_list @@ -106,8 +109,9 @@ def _recursive_to_str__(self, 
input_object) -> Any: output_object = {} for field in dataclasses.fields(input_object): if hasattr(input_object, field.name): - output_object[field.name] = self._recursive_to_str__( - getattr(input_object, field.name)) + output_object[field.name] = self._recursive_to_str__( # type: ignore[call-overload] # pylint: disable=line-too-long # noqa: E501 + getattr(input_object, field.name), + ) return output_object else: return input_object @@ -131,5 +135,5 @@ class BaseList(BaseObjectMixin): last_id: str first_id: str - def __init__(self, **kwargs): + def __init__(self, **kwargs): # pylint: disable=useless-parent-delegation super().__init__(**kwargs) diff --git a/dashscope/common/constants.py b/dashscope/common/constants.py index 8ece711..b8331f0 100644 --- a/dashscope/common/constants.py +++ b/dashscope/common/constants.py @@ -1,92 +1,96 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. from http import HTTPStatus from pathlib import Path -DASHSCOPE_API_KEY_ENV = 'DASHSCOPE_API_KEY' -DASHSCOPE_API_KEY_FILE_PATH_ENV = 'DASHSCOPE_API_KEY_FILE_PATH' -DASHSCOPE_API_REGION_ENV = 'DASHSCOPE_API_REGION' -DASHSCOPE_API_VERSION_ENV = 'DASHSCOPE_API_VERSION' +DASHSCOPE_API_KEY_ENV = "DASHSCOPE_API_KEY" +DASHSCOPE_API_KEY_FILE_PATH_ENV = "DASHSCOPE_API_KEY_FILE_PATH" +DASHSCOPE_API_REGION_ENV = "DASHSCOPE_API_REGION" +DASHSCOPE_API_VERSION_ENV = "DASHSCOPE_API_VERSION" # to disable data inspection # export DASHSCOPE_DISABLE_DATA_INSPECTION=true -DASHSCOPE_DISABLE_DATA_INSPECTION_ENV = 'DASHSCOPE_DISABLE_DATA_INSPECTION' -DEFAULT_DASHSCOPE_CACHE_PATH = Path.home().joinpath('.dashscope') +DASHSCOPE_DISABLE_DATA_INSPECTION_ENV = "DASHSCOPE_DISABLE_DATA_INSPECTION" +DEFAULT_DASHSCOPE_CACHE_PATH = Path.home().joinpath(".dashscope") DEFAULT_DASHSCOPE_API_KEY_FILE_PATH = Path.joinpath( - DEFAULT_DASHSCOPE_CACHE_PATH, 'api_key') + DEFAULT_DASHSCOPE_CACHE_PATH, + "api_key", +) DEFAULT_REQUEST_TIMEOUT_SECONDS = 300 -REQUEST_TIMEOUT_KEYWORD = 'request_timeout' 
-SERVICE_API_PATH = 'services' -DASHSCOPE_LOGGING_LEVEL_ENV = 'DASHSCOPE_LOGGING_LEVEL' +REQUEST_TIMEOUT_KEYWORD = "request_timeout" +SERVICE_API_PATH = "services" +DASHSCOPE_LOGGING_LEVEL_ENV = "DASHSCOPE_LOGGING_LEVEL" # task config keys. -PROMPT = 'prompt' -MESSAGES = 'messages' -NEGATIVE_PROMPT = 'negative_prompt' -HISTORY = 'history' -CUSTOMIZED_MODEL_ID = 'customized_model_id' -IMAGES = 'images' -REFERENCE_VIDEO_URLS = 'reference_video_urls' -TEXT_EMBEDDING_INPUT_KEY = 'texts' -SERVICE_503_MESSAGE = 'Service temporarily unavailable, possibly overloaded or not ready.' # noqa E501 +PROMPT = "prompt" +MESSAGES = "messages" +NEGATIVE_PROMPT = "negative_prompt" +HISTORY = "history" +CUSTOMIZED_MODEL_ID = "customized_model_id" +IMAGES = "images" +REFERENCE_VIDEO_URLS = "reference_video_urls" +TEXT_EMBEDDING_INPUT_KEY = "texts" +SERVICE_503_MESSAGE = "Service temporarily unavailable, possibly overloaded or not ready." # noqa E501 # pylint: disable=line-too-long WEBSOCKET_ERROR_CODE = 44 -SSE_CONTENT_TYPE = 'text/event-stream' -DEPRECATED_MESSAGE = 'history and auto_history are deprecated for qwen serial models and will be remove in future, use messages' # noqa E501 -SCENE = 'scene' -MESSAGE = 'message' -REQUEST_CONTENT_TEXT = 'text' -REQUEST_CONTENT_IMAGE = 'image' -REQUEST_CONTENT_AUDIO = 'audio' -FILE_PATH_SCHEMA = 'file://' +SSE_CONTENT_TYPE = "text/event-stream" +DEPRECATED_MESSAGE = "history and auto_history are deprecated for qwen serial models and will be remove in future, use messages" # noqa E501 # pylint: disable=line-too-long +SCENE = "scene" +MESSAGE = "message" +REQUEST_CONTENT_TEXT = "text" +REQUEST_CONTENT_IMAGE = "image" +REQUEST_CONTENT_AUDIO = "audio" +FILE_PATH_SCHEMA = "file://" ENCRYPTION_AES_SECRET_KEY_BYTES = 32 ENCRYPTION_AES_IV_LENGTH = 12 REPEATABLE_STATUS = [ - HTTPStatus.SERVICE_UNAVAILABLE, HTTPStatus.GATEWAY_TIMEOUT + HTTPStatus.SERVICE_UNAVAILABLE, + HTTPStatus.GATEWAY_TIMEOUT, ] class FilePurpose: - fine_tune = 'fine_tune' - 
assistants = 'assistants' + fine_tune = "fine_tune" + assistants = "assistants" class DeploymentStatus: - DEPLOYING = 'DEPLOYING' - SERVING = 'RUNNING' - DELETING = 'DELETING' - FAILED = 'FAILED' - PENDING = 'PENDING' + DEPLOYING = "DEPLOYING" + SERVING = "RUNNING" + DELETING = "DELETING" + FAILED = "FAILED" + PENDING = "PENDING" class ApiProtocol: - WEBSOCKET = 'websocket' - HTTP = 'http' - HTTPS = 'https' + WEBSOCKET = "websocket" + HTTP = "http" + HTTPS = "https" class HTTPMethod: - GET = 'GET' - HEAD = 'HEAD' - POST = 'POST' - PUT = 'PUT' - DELETE = 'DELETE' - CONNECT = 'CONNECT' - OPTIONS = 'OPTIONS' - TRACE = 'TRACE' - PATCH = 'PATCH' + GET = "GET" + HEAD = "HEAD" + POST = "POST" + PUT = "PUT" + DELETE = "DELETE" + CONNECT = "CONNECT" + OPTIONS = "OPTIONS" + TRACE = "TRACE" + PATCH = "PATCH" class TaskStatus: - PENDING = 'PENDING' - SUSPENDED = 'SUSPENDED' - SUCCEEDED = 'SUCCEEDED' - CANCELED = 'CANCELED' - RUNNING = 'RUNNING' - FAILED = 'FAILED' - UNKNOWN = 'UNKNOWN' - - -class Tasks(object): - TextGeneration = 'text-generation' - AutoSpeechRecognition = 'asr' + PENDING = "PENDING" + SUSPENDED = "SUSPENDED" + SUCCEEDED = "SUCCEEDED" + CANCELED = "CANCELED" + RUNNING = "RUNNING" + FAILED = "FAILED" + UNKNOWN = "UNKNOWN" + + +class Tasks(object): # pylint: disable=useless-object-inheritance + TextGeneration = "text-generation" + AutoSpeechRecognition = "asr" diff --git a/dashscope/common/env.py b/dashscope/common/env.py index 0006827..598fba0 100644 --- a/dashscope/common/env.py +++ b/dashscope/common/env.py @@ -1,22 +1,27 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
import os -from dashscope.common.constants import (DASHSCOPE_API_KEY_ENV, - DASHSCOPE_API_KEY_FILE_PATH_ENV, - DASHSCOPE_API_REGION_ENV, - DASHSCOPE_API_VERSION_ENV) +from dashscope.common.constants import ( + DASHSCOPE_API_KEY_ENV, + DASHSCOPE_API_KEY_FILE_PATH_ENV, + DASHSCOPE_API_REGION_ENV, + DASHSCOPE_API_VERSION_ENV, +) -api_region = os.environ.get(DASHSCOPE_API_REGION_ENV, 'cn-beijing') -api_version = os.environ.get(DASHSCOPE_API_VERSION_ENV, 'v1') +api_region = os.environ.get(DASHSCOPE_API_REGION_ENV, "cn-beijing") +api_version = os.environ.get(DASHSCOPE_API_VERSION_ENV, "v1") # read the api key from env api_key = os.environ.get(DASHSCOPE_API_KEY_ENV) api_key_file_path = os.environ.get(DASHSCOPE_API_KEY_FILE_PATH_ENV) # define api base url, ensure end / base_http_api_url = os.environ.get( - 'DASHSCOPE_HTTP_BASE_URL', - 'https://dashscope.aliyuncs.com/api/%s' % (api_version)) + "DASHSCOPE_HTTP_BASE_URL", + f"https://dashscope.aliyuncs.com/api/{api_version}", +) base_websocket_api_url = os.environ.get( - 'DASHSCOPE_WEBSOCKET_BASE_URL', - 'wss://dashscope.aliyuncs.com/api-ws/%s/inference' % (api_version)) + "DASHSCOPE_WEBSOCKET_BASE_URL", + f"wss://dashscope.aliyuncs.com/api-ws/{api_version}/inference", +) diff --git a/dashscope/common/error.py b/dashscope/common/error.py index 074ac2f..2b49cc7 100644 --- a/dashscope/common/error.py +++ b/dashscope/common/error.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
@@ -45,7 +46,7 @@ class UnsupportedApiProtocol(DashScopeException): pass -class NotImplemented(DashScopeException): +class NotImplemented(DashScopeException): # pylint: disable=redefined-builtin pass @@ -66,37 +67,46 @@ def __init__(self, **kwargs): self.message = None self.code = None self.request_id = None - if 'message' in kwargs: + if "message" in kwargs: import json - msg = json.loads(kwargs['message']) - if 'request_id' in msg: - self.request_id = msg['request_id'] - if 'code' in msg: - self.code = msg['code'] - if 'message' in msg: - self.message = msg['message'] + + msg = json.loads(kwargs["message"]) + if "request_id" in msg: + self.request_id = msg["request_id"] + if "code" in msg: + self.code = msg["code"] + if "message" in msg: + self.message = msg["message"] def __str__(self): - msg = 'Request failed, request_id: %s, code: %s, message: %s' % ( # noqa E501 - self.request_id, self.code, self.message) + msg = ( + f"Request failed, request_id: {self.request_id}, " + f"code: {self.code}, message: {self.message}" + ) return msg # for server send generation or inference error. 
class RequestFailure(DashScopeException): - def __init__(self, - request_id=None, - message=None, - name=None, - http_code=None): + def __init__( + self, + request_id=None, + message=None, + name=None, + http_code=None, + ): self.request_id = request_id self.message = message self.name = name self.http_code = http_code def __str__(self): - msg = 'Request failed, request_id: %s, http_code: %s error_name: %s, error_message: %s' % ( # noqa E501 - self.request_id, self.http_code, self.name, self.message) + # pylint: disable=line-too-long + msg = ( + f"Request failed, request_id: {self.request_id}, " + f"http_code: {self.http_code} error_name: {self.name}, " + f"error_message: {self.message}" + ) return msg diff --git a/dashscope/common/logging.py b/dashscope/common/logging.py index e99c0d3..a5ca4ab 100644 --- a/dashscope/common/logging.py +++ b/dashscope/common/logging.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. import logging @@ -5,23 +6,23 @@ from dashscope.common.constants import DASHSCOPE_LOGGING_LEVEL_ENV -logger = logging.getLogger('dashscope') +logger = logging.getLogger("dashscope") def enable_logging(): level = os.environ.get(DASHSCOPE_LOGGING_LEVEL_ENV, None) if level is not None: # set logging level. - if level not in ['info', 'debug']: + if level not in ["info", "debug"]: # set logging level env, but invalid value, use default. 
- level = 'info' - if level == 'info': + level = "info" + if level == "info": logger.setLevel(logging.INFO) else: logger.setLevel(logging.DEBUG) # set default logging handler console_handler = logging.StreamHandler() formatter = logging.Formatter( - '%(asctime)s - %(name)s - %(filename)s - %(funcName)s - %(lineno)d - %(levelname)s - %(message)s' # noqa E501 + "%(asctime)s - %(name)s - %(filename)s - %(funcName)s - %(lineno)d - %(levelname)s - %(message)s", # noqa E501 # pylint: disable=line-too-long ) console_handler.setFormatter(formatter) logger.addHandler(console_handler) diff --git a/dashscope/common/message_manager.py b/dashscope/common/message_manager.py index 22613e3..9b546b9 100644 --- a/dashscope/common/message_manager.py +++ b/dashscope/common/message_manager.py @@ -1,27 +1,30 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. from collections import deque from typing import List -from dashscope.api_entities.dashscope_response import (ConversationResponse, - GenerationResponse, - Message) +from dashscope.api_entities.dashscope_response import ( + ConversationResponse, + GenerationResponse, + Message, +) -class MessageManager(object): +class MessageManager(object): # pylint: disable=useless-object-inheritance DEFAULT_MAXIMUM_MESSAGES = 100 def __init__(self, max_length: int = None): if max_length is None: self._dq = deque(maxlen=MessageManager.DEFAULT_MAXIMUM_MESSAGES) else: - self._dq = deque(maxlen=max_length) + self._dq = deque(maxlen=max_length) # type: ignore[has-type] def add_generation_response(self, response: GenerationResponse): - self._dq.append(Message.from_generation_response(response)) + self._dq.append(Message.from_generation_response(response)) # type: ignore[has-type] # pylint: disable=line-too-long # noqa: E501 def add_conversation_response(self, response: ConversationResponse): - self._dq.append(Message.from_conversation_response(response)) + self._dq.append(Message.from_conversation_response(response)) # type: 
ignore[has-type] # pylint: disable=line-too-long # noqa: E501 def add(self, message: Message): """Add message to message manager @@ -29,7 +32,7 @@ def add(self, message: Message): Args: message (Message): The message to add. """ - self._dq.append(message) + self._dq.append(message) # type: ignore[has-type] def get(self) -> List[Message]: - return list(self._dq) + return list(self._dq) # type: ignore[has-type] diff --git a/dashscope/common/utils.py b/dashscope/common/utils.py index 1f0b151..bd4b93d 100644 --- a/dashscope/common/utils.py +++ b/dashscope/common/utils.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. import asyncio @@ -22,7 +23,7 @@ def is_validate_fine_tune_file(file_path): - with open(file_path, encoding='utf-8') as f: + with open(file_path, encoding="utf-8") as f: for line in f: try: json.loads(line) @@ -41,9 +42,9 @@ def _get_task_group_and_task(module_name): Returns: (str, str): task_group and task """ - pkg, task = module_name.rsplit('.', 1) - task = task.replace('_', '-') - _, task_group = pkg.rsplit('.', 1) + pkg, task = module_name.rsplit(".", 1) + task = task.replace("_", "-") + _, task_group = pkg.rsplit(".", 1) return task_group, task @@ -57,7 +58,7 @@ def is_path(path: str): bool: If path return True, otherwise False. """ url_parsed = urlparse(path) - if url_parsed.scheme in ('file', ''): + if url_parsed.scheme in ("file", ""): return os.path.exists(url_parsed.path) else: return False @@ -73,7 +74,8 @@ def is_url(url: str): bool: If is url return True, otherwise False. 
""" url_parsed = urlparse(url) - if url_parsed.scheme in ('http', 'https', 'oss'): + # pylint: disable=simplifiable-if-statement + if url_parsed.scheme in ("http", "https", "oss"): return True else: return False @@ -104,9 +106,11 @@ def iter_thread(loop, message_queue): break message_queue = queue.Queue() - x = threading.Thread(target=iter_thread, - args=(loop, message_queue), - name='iter_async_thread') + x = threading.Thread( + target=iter_thread, + args=(loop, message_queue), + name="iter_async_thread", + ) x.start() while True: finished, error, obj = message_queue.get() @@ -114,13 +118,12 @@ def iter_thread(loop, message_queue): if error is not None: yield DashScopeAPIResponse( -1, - '', - 'Unknown', - message='Error type: %s, message: %s' % - (type(error), error)) + "", + "Unknown", + message=f"Error type: {type(error)}, message: {error}", + ) break - else: - yield obj + yield obj # pylint: disable=no-else-break def async_to_sync(async_generator): @@ -129,69 +132,73 @@ def async_to_sync(async_generator): def get_user_agent(): - ua = 'dashscope/%s; python/%s; platform/%s; processor/%s' % ( - __version__, - platform.python_version(), - platform.platform(), - platform.processor(), + ua = ( + f"dashscope/{__version__}; python/{platform.python_version()}; " + f"platform/{platform.platform()}; " + f"processor/{platform.processor()}" ) return ua def default_headers(api_key: str = None) -> Dict[str, str]: - ua = 'dashscope/%s; python/%s; platform/%s; processor/%s' % ( - __version__, - platform.python_version(), - platform.platform(), - platform.processor(), + ua = ( + f"dashscope/{__version__}; python/{platform.python_version()}; " + f"platform/{platform.platform()}; " + f"processor/{platform.processor()}" ) - headers = {'user-agent': ua} + headers = {"user-agent": ua} if api_key is None: api_key = get_default_api_key() - headers['Authorization'] = 'Bearer %s' % api_key - headers['Accept'] = 'application/json' + headers["Authorization"] = f"Bearer {api_key}" + 
headers["Accept"] = "application/json" return headers def join_url(base_url, *args): - if not base_url.endswith('/'): - base_url = base_url + '/' + if not base_url.endswith("/"): + base_url = base_url + "/" url = base_url for arg in args: if arg is not None: - url += arg + '/' + url += arg + "/" return url[:-1] async def _handle_aiohttp_response(response: aiohttp.ClientResponse): - request_id = '' + request_id = "" if response.status == HTTPStatus.OK: json_content = await response.json() - if 'request_id' in json_content: - request_id = json_content['request_id'] - return DashScopeAPIResponse(request_id=request_id, - status_code=HTTPStatus.OK, - output=json_content) + if "request_id" in json_content: + request_id = json_content["request_id"] + return DashScopeAPIResponse( + request_id=request_id, + status_code=HTTPStatus.OK, + output=json_content, + ) else: - if 'application/json' in response.content_type: + if "application/json" in response.content_type: error = await response.json() - msg = '' - if 'message' in error: - msg = error['message'] - if 'request_id' in error: - request_id = error['request_id'] - return DashScopeAPIResponse(request_id=request_id, - status_code=response.status, - output=None, - code=error['code'], - message=msg) + msg = "" + if "message" in error: + msg = error["message"] + if "request_id" in error: + request_id = error["request_id"] + return DashScopeAPIResponse( + request_id=request_id, + status_code=response.status, + output=None, + code=error["code"], + message=msg, + ) else: msg = await response.read() - return DashScopeAPIResponse(request_id=request_id, - status_code=response.status, - output=None, - code='Unknown', - message=msg) + return DashScopeAPIResponse( + request_id=request_id, + status_code=response.status, + output=None, + code="Unknown", + message=msg, + ) @dataclass @@ -200,7 +207,12 @@ class SSEEvent: eventType: str data: str - def __init__(self, id: str, type: str, data: str): + def __init__( # pylint: 
disable=redefined-builtin + self, + id: str, + type: str, + data: str, + ): self.id = id self.eventType = type self.data = data @@ -210,27 +222,27 @@ def _handle_stream(response: requests.Response): # TODO define done message. is_error = False status_code = HTTPStatus.BAD_REQUEST - event = SSEEvent(None, None, None) + event = SSEEvent(None, None, None) # type: ignore[arg-type] eventType = None for line in response.iter_lines(): if line: - line = line.decode('utf8') - line = line.rstrip('\n').rstrip('\r') - if line.startswith('id:'): - id = line[len('id:'):] + line = line.decode("utf8") + line = line.rstrip("\n").rstrip("\r") + if line.startswith("id:"): + id = line[len("id:") :] # pylint: disable=redefined-builtin event.id = id.strip() - elif line.startswith('event:'): - eventType = line[len('event:'):] + elif line.startswith("event:"): + eventType = line[len("event:") :] event.eventType = eventType.strip() - if eventType == 'error': + if eventType == "error": is_error = True - elif line.startswith('status:'): - status_code = line[len('status:'):] + elif line.startswith("status:"): + status_code = line[len("status:") :] status_code = int(status_code.strip()) - elif line.startswith('data:'): - line = line[len('data:'):] + elif line.startswith("data:"): + line = line[len("data:") :] event.data = line.strip() - if eventType is not None and eventType == 'done': + if eventType is not None and eventType == "done": continue yield (is_error, status_code, event) if is_error: @@ -241,52 +253,65 @@ def _handle_stream(response: requests.Response): def _handle_error_message(error, status_code, flattened_output): code = None - msg = '' - request_id = '' + msg = "" + request_id = "" if flattened_output: - error['status_code'] = status_code + error["status_code"] = status_code return error - if 'message' in error: - msg = error['message'] - if 'msg' in error: - msg = error['msg'] - if 'code' in error: - code = error['code'] - if 'request_id' in error: - request_id = 
error['request_id'] - return DashScopeAPIResponse(request_id=request_id, - status_code=status_code, - code=code, - message=msg) + if "message" in error: + msg = error["message"] + if "msg" in error: + msg = error["msg"] + if "code" in error: + code = error["code"] + if "request_id" in error: + request_id = error["request_id"] + return DashScopeAPIResponse( + request_id=request_id, + status_code=status_code, + code=code, + message=msg, + ) def _handle_http_failed_response( - response: requests.Response, - flattened_output: bool = False) -> DashScopeAPIResponse: - request_id = '' - if 'application/json' in response.headers.get('content-type', ''): + response: requests.Response, + flattened_output: bool = False, +) -> DashScopeAPIResponse: + request_id = "" + if "application/json" in response.headers.get("content-type", ""): error = response.json() - return _handle_error_message(error, response.status_code, - flattened_output) - elif SSE_CONTENT_TYPE in response.headers.get('content-type', ''): - msgs = response.content.decode('utf-8').split('\n') + return _handle_error_message( + error, + response.status_code, + flattened_output, + ) + elif SSE_CONTENT_TYPE in response.headers.get("content-type", ""): + msgs = response.content.decode("utf-8").split("\n") for msg in msgs: - if msg.startswith('data:'): - error = json.loads(msg.replace('data:', '').strip()) - return _handle_error_message(error, response.status_code, - flattened_output) - return DashScopeAPIResponse(request_id=request_id, - status_code=response.status_code, - code='Unknown', - message=msgs) + if msg.startswith("data:"): + error = json.loads(msg.replace("data:", "").strip()) + return _handle_error_message( + error, + response.status_code, + flattened_output, + ) + return DashScopeAPIResponse( + request_id=request_id, + status_code=response.status_code, + code="Unknown", + message=msgs, + ) else: - msg = response.content.decode('utf-8') + msg = response.content.decode("utf-8") if flattened_output: - return 
{'status_code': response.status_code, 'message': msg} - return DashScopeAPIResponse(request_id=request_id, - status_code=response.status_code, - code='Unknown', - message=msg) + return {"status_code": response.status_code, "message": msg} # type: ignore[return-value] # pylint: disable=line-too-long # noqa: E501 + return DashScopeAPIResponse( + request_id=request_id, + status_code=response.status_code, + code="Unknown", + message=msg, + ) async def _handle_aio_stream(response): @@ -295,15 +320,15 @@ async def _handle_aio_stream(response): status_code = HTTPStatus.BAD_REQUEST async for line in response.content: if line: - line = line.decode('utf8') - line = line.rstrip('\n').rstrip('\r') - if line.startswith('event:error'): + line = line.decode("utf8") + line = line.rstrip("\n").rstrip("\r") + if line.startswith("event:error"): is_error = True - elif line.startswith('status:'): - status_code = line[len('status:'):] + elif line.startswith("status:"): + status_code = line[len("status:") :] status_code = int(status_code.strip()) - elif line.startswith('data:'): - line = line[len('data:'):] + elif line.startswith("data:"): + line = line[len("data:") :] yield (is_error, status_code, line) if is_error: break @@ -312,10 +337,11 @@ async def _handle_aio_stream(response): async def _handle_aiohttp_failed_response( - response: requests.Response, - flattened_output: bool = False) -> DashScopeAPIResponse: - request_id = '' - if 'application/json' in response.content_type: + response: requests.Response, + flattened_output: bool = False, +) -> DashScopeAPIResponse: + request_id = "" + if "application/json" in response.content_type: error = await response.json() return _handle_error_message(error, response.status, flattened_output) elif SSE_CONTENT_TYPE in response.content_type: @@ -323,17 +349,21 @@ async def _handle_aiohttp_failed_response( error = json.loads(data) return _handle_error_message(error, response.status, flattened_output) else: - msg = 
response.content.decode('utf-8') + msg = response.content.decode("utf-8") if flattened_output: - return {'status_code': response.status, 'message': msg} - return DashScopeAPIResponse(request_id=request_id, - status_code=response.status, - code='Unknown', - message=msg) - - -def _handle_http_response(response: requests.Response, - flattened_output: bool = False): + return {"status_code": response.status, "message": msg} # type: ignore[return-value] # pylint: disable=line-too-long # noqa: E501 + return DashScopeAPIResponse( + request_id=request_id, + status_code=response.status, + code="Unknown", + message=msg, + ) + + +def _handle_http_response( + response: requests.Response, + flattened_output: bool = False, +): response = _handle_http_stream_response(response, flattened_output) _, output = next(response) try: @@ -343,11 +373,16 @@ def _handle_http_response(response: requests.Response, return output -def _handle_http_stream_response(response: requests.Response, - flattened_output: bool = False): - request_id = '' - if (response.status_code == HTTPStatus.OK - and SSE_CONTENT_TYPE in response.headers.get('content-type', '')): +# pylint: disable=R1702,too-many-branches,too-many-statements +def _handle_http_stream_response( + response: requests.Response, + flattened_output: bool = False, +): + request_id = "" + if ( + response.status_code == HTTPStatus.OK + and SSE_CONTENT_TYPE in response.headers.get("content-type", "") + ): for is_error, status_code, event in _handle_stream(response): if not is_error: try: @@ -355,41 +390,43 @@ def _handle_http_stream_response(response: requests.Response, usage = None msg = json.loads(event.data) if flattened_output: - msg['status_code'] = response.status_code + msg["status_code"] = response.status_code yield event.eventType, msg else: - logger.debug('Stream message: %s' % msg) + logger.debug("Stream message: %s", msg) if not is_error: - if 'output' in msg: - output = msg['output'] - if 'usage' in msg: - usage = msg['usage'] - if 
'request_id' in msg: - request_id = msg['request_id'] + if "output" in msg: + output = msg["output"] + if "usage" in msg: + usage = msg["usage"] + if "request_id" in msg: + request_id = msg["request_id"] yield event.eventType, DashScopeAPIResponse( request_id=request_id, status_code=HTTPStatus.OK, output=output, - usage=usage) + usage=usage, + ) except json.JSONDecodeError as e: if flattened_output: yield event.eventType, { - 'status_code': response.status_code, - 'message': e.message + "status_code": response.status_code, + "message": e.message, } else: yield event.eventType, DashScopeAPIResponse( request_id=request_id, status_code=HTTPStatus.BAD_REQUEST, output=None, - code='Unknown', - message=event.data) + code="Unknown", + message=event.data, + ) continue else: if flattened_output: yield event.eventType, { - 'status_code': status_code, - 'message': event.data + "status_code": status_code, + "message": event.data, } else: msg = json.loads(event.eventType) @@ -397,42 +434,49 @@ def _handle_http_stream_response(response: requests.Response, request_id=request_id, status_code=status_code, output=None, - code=msg['code'] - if 'code' in msg else None, # noqa E501 - message=msg['message'] - if 'message' in msg else None) # noqa E501 - elif response.status_code == HTTPStatus.OK or response.status_code == HTTPStatus.CREATED: + code=msg["code"] + if "code" in msg + else None, # noqa E501 + message=msg["message"] if "message" in msg else None, + ) # noqa E501 + # pylint: disable=consider-using-in + elif ( + response.status_code == HTTPStatus.OK + or response.status_code == HTTPStatus.CREATED + ): json_content = response.json() if flattened_output: - json_content['status_code'] = response.status_code + json_content["status_code"] = response.status_code yield None, json_content else: output = None usage = None code = None - msg = '' - if 'data' in json_content: - output = json_content['data'] - if 'code' in json_content: - code = json_content['code'] - if 'message' in 
json_content: - msg = json_content['message'] - if 'output' in json_content: - output = json_content['output'] - if 'usage' in json_content: - usage = json_content['usage'] - if 'request_id' in json_content: - request_id = json_content['request_id'] - json_content.pop('request_id', None) - - if 'data' not in json_content and 'output' not in json_content: + msg = "" + if "data" in json_content: + output = json_content["data"] + if "code" in json_content: + code = json_content["code"] + if "message" in json_content: + msg = json_content["message"] + if "output" in json_content: + output = json_content["output"] + if "usage" in json_content: + usage = json_content["usage"] + if "request_id" in json_content: + request_id = json_content["request_id"] + json_content.pop("request_id", None) + + if "data" not in json_content and "output" not in json_content: output = json_content - yield None, DashScopeAPIResponse(request_id=request_id, - status_code=response.status_code, - code=code, - output=output, - usage=usage, - message=msg) + yield None, DashScopeAPIResponse( + request_id=request_id, + status_code=response.status_code, + code=code, # type: ignore[arg-type] + output=output, + usage=usage, + message=msg, + ) else: yield None, _handle_http_failed_response(response, flattened_output) diff --git a/dashscope/customize/customize_types.py b/dashscope/customize/customize_types.py index 35c2723..d33408b 100644 --- a/dashscope/customize/customize_types.py +++ b/dashscope/customize/customize_types.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
from dataclasses import dataclass @@ -6,7 +7,7 @@ from dashscope.common.base_type import BaseObjectMixin -__all__ = ['Deployment', 'FineTune', 'DeploymentList', 'FineTuneList'] +__all__ = ["Deployment", "FineTune", "DeploymentList", "FineTuneList"] @dataclass(init=False) @@ -15,7 +16,7 @@ class DashScopeBaseList(BaseObjectMixin): page_size: int total: int - def __init__(self, **kwargs): + def __init__(self, **kwargs): # pylint: disable=useless-parent-delegation super().__init__(**kwargs) @@ -26,7 +27,7 @@ class DashScopeBase(BaseObjectMixin): code: str message: str - def __init__(self, **kwargs): + def __init__(self, **kwargs): # pylint: disable=useless-parent-delegation super().__init__(**kwargs) @@ -50,7 +51,7 @@ class FineTuneOutput(BaseObjectMixin): group: str usage: int - def __init__(self, **kwargs): + def __init__(self, **kwargs): # pylint: disable=useless-parent-delegation super().__init__(**kwargs) @@ -60,9 +61,9 @@ class FineTune(DashScopeBase): usage: Dict def __init__(self, **kwargs): - status_code = kwargs.get('status_code', None) + status_code = kwargs.get("status_code", None) if status_code == HTTPStatus.OK: - self.output = FineTuneOutput(**kwargs.pop('output', {})) + self.output = FineTuneOutput(**kwargs.pop("output", {})) super().__init__(**kwargs) @@ -72,7 +73,7 @@ class FineTuneListOutput(DashScopeBaseList): def __init__(self, **kwargs): self.jobs = [] - for job in kwargs.pop('jobs', []): + for job in kwargs.pop("jobs", []): self.jobs.append(FineTuneOutput(**job)) super().__init__(**kwargs) @@ -82,9 +83,9 @@ class FineTuneList(DashScopeBase): output: FineTuneListOutput def __init__(self, **kwargs): - status_code = kwargs.get('status_code', None) + status_code = kwargs.get("status_code", None) if status_code == HTTPStatus.OK: - self.output = FineTuneListOutput(**kwargs.pop('output', {})) + self.output = FineTuneListOutput(**kwargs.pop("output", {})) super().__init__(**kwargs) @@ -98,9 +99,9 @@ class FineTuneCancel(DashScopeBase): output: 
CancelDeleteStatus def __init__(self, **kwargs): - status_code = kwargs.get('status_code', None) + status_code = kwargs.get("status_code", None) if status_code == HTTPStatus.OK: - self.output = CancelDeleteStatus(**kwargs.pop('output', {})) + self.output = CancelDeleteStatus(**kwargs.pop("output", {})) super().__init__(**kwargs) @@ -109,9 +110,9 @@ class FineTuneDelete(DashScopeBase): output: CancelDeleteStatus def __init__(self, **kwargs): - status_code = kwargs.get('status_code', None) + status_code = kwargs.get("status_code", None) if status_code == HTTPStatus.OK: - self.output = CancelDeleteStatus(**kwargs.pop('output', {})) + self.output = CancelDeleteStatus(**kwargs.pop("output", {})) super().__init__(**kwargs) @@ -120,9 +121,9 @@ class FineTuneEvent(DashScopeBase): output: str def __init__(self, **kwargs): - status_code = kwargs.get('status_code', None) + status_code = kwargs.get("status_code", None) if status_code == HTTPStatus.OK: - self.output = kwargs.pop('output', {}) + self.output = kwargs.pop("output", {}) super().__init__(**kwargs) @@ -142,7 +143,7 @@ class DeploymentOutput(BaseObjectMixin): modifier: str creator: str - def __init__(self, **kwargs): + def __init__(self, **kwargs): # pylint: disable=useless-parent-delegation super().__init__(**kwargs) @@ -151,7 +152,7 @@ class Deployment(DashScopeBase): output: DeploymentOutput def __init__(self, **kwargs): - output = kwargs.pop('output', {}) + output = kwargs.pop("output", {}) if output: self.output = DeploymentOutput(**output) else: @@ -165,7 +166,7 @@ class DeploymentListOutput(DashScopeBaseList): def __init__(self, **kwargs): self.deployments = [] - for job in kwargs.pop('deployments', []): + for job in kwargs.pop("deployments", []): self.deployments.append(DeploymentOutput(**job)) super().__init__(**kwargs) @@ -175,9 +176,9 @@ class DeploymentList(BaseObjectMixin): output: DeploymentListOutput def __init__(self, **kwargs): - status_code = kwargs.get('status_code', None) + status_code = 
kwargs.get("status_code", None) if status_code == HTTPStatus.OK: - self.output = DeploymentListOutput(**kwargs.pop('output', {})) + self.output = DeploymentListOutput(**kwargs.pop("output", {})) super().__init__(**kwargs) @@ -186,7 +187,7 @@ class DeploymentDelete(DashScopeBase): output: CancelDeleteStatus def __init__(self, **kwargs): - status_code = kwargs.get('status_code', None) + status_code = kwargs.get("status_code", None) if status_code == HTTPStatus.OK: - self.output = CancelDeleteStatus(**kwargs.pop('output', {})) + self.output = CancelDeleteStatus(**kwargs.pop("output", {})) super().__init__(**kwargs) diff --git a/dashscope/customize/deployments.py b/dashscope/customize/deployments.py index f0a0fab..e0add2b 100644 --- a/dashscope/customize/deployments.py +++ b/dashscope/customize/deployments.py @@ -1,25 +1,45 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. -from dashscope.client.base_api import (CreateMixin, DeleteMixin, GetMixin, - ListMixin, PutMixin, StreamEventMixin) -from dashscope.customize.customize_types import (Deployment, DeploymentDelete, - DeploymentList) - - -class Deployments(CreateMixin, DeleteMixin, ListMixin, GetMixin, - StreamEventMixin, PutMixin): - SUB_PATH = 'deployments' +from dashscope.client.base_api import ( + CreateMixin, + DeleteMixin, + GetMixin, + ListMixin, + PutMixin, + StreamEventMixin, +) +from dashscope.customize.customize_types import ( + Deployment, + DeploymentDelete, + DeploymentList, +) + + +class Deployments( + CreateMixin, + DeleteMixin, + ListMixin, + GetMixin, + StreamEventMixin, + PutMixin, +): + SUB_PATH = "deployments" """Deploy a model. 
""" + @classmethod - def call(cls, - model: str, - capacity: int, - version: str = None, - suffix: str = None, - api_key: str = None, - workspace: str = None, - **kwargs) -> Deployment: + # type: ignore[override] + def call( # pylint: disable=arguments-renamed # type: ignore[override] + cls, + model: str, + capacity: int, + version: str = None, + suffix: str = None, + api_key: str = None, + workspace: str = None, + **kwargs, + ) -> Deployment: """Call to deployment a model service. Args: @@ -36,25 +56,29 @@ def call(cls, Returns: Deployment: _description_ """ - req = {'model_name': model, 'capacity': capacity} + req = {"model_name": model, "capacity": capacity} if version is not None: - req['model_version'] = version + req["model_version"] = version if suffix is not None: - req['suffix'] = suffix - response = super().call(req, - api_key=api_key, - workspace=workspace, - **kwargs) + req["suffix"] = suffix + response = super().call( + req, + api_key=api_key, + workspace=workspace, + **kwargs, + ) return Deployment(**response) @classmethod - def list(cls, - page_no=1, - page_size=10, - api_key: str = None, - workspace: str = None, - **kwargs) -> DeploymentList: + def list( # type: ignore[override] + cls, + page_no=1, + page_size=10, + api_key: str = None, + workspace: str = None, + **kwargs, + ) -> DeploymentList: """List deployments. Args: @@ -67,19 +91,23 @@ def list(cls, Returns: Deployment: The deployment list. 
""" - response = super().list(page_no=page_no, - page_size=page_size, - api_key=api_key, - workspace=workspace, - **kwargs) + response = super().list( + page_no=page_no, + page_size=page_size, + api_key=api_key, + workspace=workspace, + **kwargs, + ) return DeploymentList(**response) @classmethod - def get(cls, - deployed_model: str, - api_key: str = None, - workspace: str = None, - **kwargs) -> Deployment: + def get( # type: ignore[override] + cls, + deployed_model: str, + api_key: str = None, + workspace: str = None, + **kwargs, + ) -> Deployment: """Get model deployment information. Args: @@ -90,18 +118,22 @@ def get(cls, Returns: Deployment: The deployment information. """ - response = super().get(deployed_model, - api_key=api_key, - workspace=workspace, - **kwargs) + response = super().get( + deployed_model, + api_key=api_key, + workspace=workspace, + **kwargs, + ) return Deployment(**response) @classmethod - def delete(cls, - deployed_model: str, - api_key: str = None, - workspace: str = None, - **kwargs) -> DeploymentDelete: + def delete( # type: ignore[override] + cls, + deployed_model: str, + api_key: str = None, + workspace: str = None, + **kwargs, + ) -> DeploymentDelete: """Delete model deployment. Args: @@ -112,19 +144,23 @@ def delete(cls, Returns: Deployment: The delete result. """ - response = super().delete(deployed_model, - api_key=api_key, - workspace=workspace, - **kwargs) + response = super().delete( + deployed_model, + api_key=api_key, + workspace=workspace, + **kwargs, + ) return DeploymentDelete(**response) @classmethod - def scale(cls, - deployed_model: str, - capacity: int, - api_key: str = None, - workspace: str = None, - **kwargs) -> Deployment: + def scale( + cls, + deployed_model: str, + capacity: int, + api_key: str = None, + workspace: str = None, + **kwargs, + ) -> Deployment: """Scaling model deployment. Args: @@ -135,12 +171,14 @@ def scale(cls, Returns: Deployment: The delete result. 
""" - req = {'deployed_model': deployed_model, 'capacity': capacity} - path = '%s/%s/scale' % (cls.SUB_PATH.lower(), deployed_model) - response = super().put(deployed_model, - req, - path=path, - api_key=api_key, - workspace=workspace, - **kwargs) + req = {"deployed_model": deployed_model, "capacity": capacity} + path = f"{cls.SUB_PATH.lower()}/{deployed_model}/scale" + response = super().put( + deployed_model, + req, + path=path, + api_key=api_key, + workspace=workspace, + **kwargs, + ) return Deployment(**response) diff --git a/dashscope/customize/finetunes.py b/dashscope/customize/finetunes.py index fa8ca0d..e3a9a81 100644 --- a/dashscope/customize/finetunes.py +++ b/dashscope/customize/finetunes.py @@ -1,32 +1,54 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. import time from http import HTTPStatus from typing import Iterator, Union -from dashscope.client.base_api import (CancelMixin, CreateMixin, DeleteMixin, - GetStatusMixin, ListMixin, LogMixin, - StreamEventMixin) +from dashscope.client.base_api import ( + CancelMixin, + CreateMixin, + DeleteMixin, + GetStatusMixin, + ListMixin, + LogMixin, + StreamEventMixin, +) from dashscope.common.constants import TaskStatus -from dashscope.customize.customize_types import (FineTune, FineTuneCancel, - FineTuneDelete, FineTuneEvent, - FineTuneList) +from dashscope.customize.customize_types import ( + FineTune, + FineTuneCancel, + FineTuneDelete, + FineTuneEvent, + FineTuneList, +) -class FineTunes(CreateMixin, CancelMixin, DeleteMixin, ListMixin, - GetStatusMixin, StreamEventMixin, LogMixin): - SUB_PATH = 'fine-tunes' +class FineTunes( + CreateMixin, + CancelMixin, + DeleteMixin, + ListMixin, + GetStatusMixin, + StreamEventMixin, + LogMixin, +): + SUB_PATH = "fine-tunes" @classmethod - def call(cls, - model: str, - training_file_ids: Union[list, str], - validation_file_ids: Union[list, str] = None, - mode: str = None, - hyper_parameters: dict = {}, - api_key: str = None, - workspace: str = 
None, - **kwargs) -> FineTune: + # type: ignore[override] + # pylint: disable=arguments-renamed,dangerous-default-value + def call( # noqa: E501 # type: ignore[override] + cls, + model: str, + training_file_ids: Union[list, str], + validation_file_ids: Union[list, str] = None, + mode: str = None, + hyper_parameters: dict = {}, + api_key: str = None, + workspace: str = None, + **kwargs, + ) -> FineTune: """Create fine-tune job Args: @@ -49,27 +71,31 @@ def call(cls, if validation_file_ids and isinstance(validation_file_ids, str): validation_file_ids = [validation_file_ids] request = { - 'model': model, - 'training_file_ids': training_file_ids, - 'validation_file_ids': validation_file_ids, - 'hyper_parameters': hyper_parameters if hyper_parameters else {}, + "model": model, + "training_file_ids": training_file_ids, + "validation_file_ids": validation_file_ids, + "hyper_parameters": hyper_parameters if hyper_parameters else {}, } if mode is not None: - request['training_type'] = mode - if 'finetuned_output' in kwargs: - request['finetuned_output'] = kwargs['finetuned_output'] - resp = super().call(request, - api_key=api_key, - workspace=workspace, - **kwargs) + request["training_type"] = mode + if "finetuned_output" in kwargs: + request["finetuned_output"] = kwargs["finetuned_output"] + resp = super().call( + request, + api_key=api_key, + workspace=workspace, + **kwargs, + ) return FineTune(**resp) @classmethod - def cancel(cls, - job_id: str, - api_key: str = None, - workspace: str = None, - **kwargs) -> FineTuneCancel: + def cancel( # type: ignore[override] + cls, + job_id: str, + api_key: str = None, + workspace: str = None, + **kwargs, + ) -> FineTuneCancel: """Cancel a running fine-tune job. Args: @@ -81,19 +107,23 @@ def cancel(cls, Returns: FineTune: The request result. 
""" - rsp = super().cancel(job_id, - api_key=api_key, - workspace=workspace, - **kwargs) + rsp = super().cancel( + job_id, + api_key=api_key, + workspace=workspace, + **kwargs, + ) return FineTuneCancel(**rsp) @classmethod - def list(cls, - page_no=1, - page_size=10, - api_key: str = None, - workspace: str = None, - **kwargs) -> FineTuneList: + def list( # type: ignore[override] + cls, + page_no=1, + page_size=10, + api_key: str = None, + workspace: str = None, + **kwargs, + ) -> FineTuneList: """List fine-tune job. Args: @@ -105,19 +135,23 @@ def list(cls, Returns: FineTune: The fine-tune jobs in the result. """ - response = super().list(page_no=page_no, - page_size=page_size, - api_key=api_key, - workspace=workspace, - **kwargs) + response = super().list( + page_no=page_no, + page_size=page_size, + api_key=api_key, + workspace=workspace, + **kwargs, + ) return FineTuneList(**response) @classmethod - def get(cls, - job_id: str, - api_key: str = None, - workspace: str = None, - **kwargs) -> FineTune: + def get( # type: ignore[override] + cls, + job_id: str, + api_key: str = None, + workspace: str = None, + **kwargs, + ) -> FineTune: """Get fine-tune job information. Args: @@ -128,18 +162,22 @@ def get(cls, Returns: FineTune: The job info """ - response = super().get(job_id, - api_key=api_key, - workspace=workspace, - **kwargs) + response = super().get( + job_id, + api_key=api_key, + workspace=workspace, + **kwargs, + ) return FineTune(**response) @classmethod - def delete(cls, - job_id: str, - api_key: str = None, - workspace: str = None, - **kwargs) -> FineTuneDelete: + def delete( # type: ignore[override] + cls, + job_id: str, + api_key: str = None, + workspace: str = None, + **kwargs, + ) -> FineTuneDelete: """Delete a fine-tune job. Args: @@ -150,18 +188,23 @@ def delete(cls, Returns: FineTune: The delete result. 
""" - rsp = super().delete(job_id, - api_key=api_key, - workspace=workspace, - **kwargs) + rsp = super().delete( + job_id, + api_key=api_key, + workspace=workspace, + **kwargs, + ) return FineTuneDelete(**rsp) @classmethod - def stream_events(cls, - job_id: str, - api_key: str = None, - workspace: str = None, - **kwargs) -> Iterator[FineTuneEvent]: + def stream_events( # type: ignore[override] # pylint: disable=arguments-renamed # noqa: E501 + # type: ignore[override] + cls, + job_id: str, + api_key: str = None, + workspace: str = None, + **kwargs, + ) -> Iterator[FineTuneEvent]: """Get fine-tune job events. Args: @@ -172,22 +215,26 @@ def stream_events(cls, Returns: FineTune: The job log events. """ - responses = super().stream_events(job_id, - api_key=api_key, - workspace=workspace, - **kwargs) + responses = super().stream_events( + job_id, + api_key=api_key, + workspace=workspace, + **kwargs, + ) for rsp in responses: yield FineTuneEvent(**rsp) @classmethod - def logs(cls, - job_id: str, - *, - offset=1, - line=1000, - api_key: str = None, - workspace: str = None, - **kwargs) -> FineTune: + def logs( # type: ignore[override] + cls, + job_id: str, + *, + offset=1, + line=1000, + api_key: str = None, + workspace: str = None, + **kwargs, + ) -> FineTune: """Get log of the job. 
Args: @@ -200,28 +247,35 @@ def logs(cls, Returns: FineTune: The response """ - return super().logs(job_id, - offset=offset, - line=line, - workspace=workspace, - api_key=api_key) + return super().logs( # type: ignore[return-value] + job_id, + offset=offset, + line=line, + workspace=workspace, + api_key=api_key, + ) @classmethod - def wait(cls, - job_id: str, - api_key: str = None, - workspace: str = None, - **kwargs): + def wait( + cls, + job_id: str, + api_key: str = None, + workspace: str = None, + **kwargs, + ): try: while True: - rsp = FineTunes.get(job_id, - api_key=api_key, - workspace=workspace, - **kwargs) + rsp = FineTunes.get( + job_id, + api_key=api_key, + workspace=workspace, + **kwargs, + ) if rsp.status_code == HTTPStatus.OK: - if rsp.output['status'] in [ - TaskStatus.FAILED, TaskStatus.CANCELED, - TaskStatus.SUCCEEDED + if rsp.output["status"] in [ + TaskStatus.FAILED, + TaskStatus.CANCELED, + TaskStatus.SUCCEEDED, ]: return rsp else: @@ -229,6 +283,8 @@ def wait(cls, else: return rsp except Exception: + # pylint: disable=broad-exception-raised,raise-missing-from raise Exception( - 'You can stream output via: dashscope fine_tunes.stream -j %s' - % job_id) + f"You can stream output via: dashscope fine_tunes.stream -j " + f"{job_id}", + ) diff --git a/dashscope/embeddings/__init__.py b/dashscope/embeddings/__init__.py index e71565f..998f1a1 100644 --- a/dashscope/embeddings/__init__.py +++ b/dashscope/embeddings/__init__.py @@ -1,7 +1,12 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
from .batch_text_embedding import BatchTextEmbedding from .batch_text_embedding_response import BatchTextEmbeddingResponse from .text_embedding import TextEmbedding -__all__ = [TextEmbedding, BatchTextEmbedding, BatchTextEmbeddingResponse] +__all__ = [ + "TextEmbedding", + "BatchTextEmbedding", + "BatchTextEmbeddingResponse", +] diff --git a/dashscope/embeddings/batch_text_embedding.py b/dashscope/embeddings/batch_text_embedding.py index 653d304..d3af2d1 100644 --- a/dashscope/embeddings/batch_text_embedding.py +++ b/dashscope/embeddings/batch_text_embedding.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. from typing import Union @@ -6,26 +7,30 @@ from dashscope.client.base_api import BaseAsyncApi from dashscope.common.error import InputRequired from dashscope.common.utils import _get_task_group_and_task -from dashscope.embeddings.batch_text_embedding_response import \ - BatchTextEmbeddingResponse +from dashscope.embeddings.batch_text_embedding_response import ( + BatchTextEmbeddingResponse, +) class BatchTextEmbedding(BaseAsyncApi): - task = 'text-embedding' - function = 'text-embedding' + task = "text-embedding" + function = "text-embedding" """API for async text embedding. """ + class Models: - text_embedding_async_v1 = 'text-embedding-async-v1' - text_embedding_async_v2 = 'text-embedding-async-v2' + text_embedding_async_v1 = "text-embedding-async-v1" + text_embedding_async_v2 = "text-embedding-async-v2" @classmethod - def call(cls, - model: str, - url: str, - api_key: str = None, - workspace: str = None, - **kwargs) -> BatchTextEmbeddingResponse: + def call( # type: ignore[override] + cls, + model: str, + url: str, + api_key: str = None, + workspace: str = None, + **kwargs, + ) -> BatchTextEmbeddingResponse: """Call async text embedding service and get result. Args: @@ -51,19 +56,23 @@ def call(cls, Returns: AsyncTextEmbeddingResponse: The async text embedding task result. 
""" - return super().call(model, - url, - api_key=api_key, - workspace=workspace, - **kwargs) + return super().call( # type: ignore[return-value] + model, + url, + api_key=api_key, + workspace=workspace, + **kwargs, + ) @classmethod - def async_call(cls, - model: str, - url: str, - api_key: str = None, - workspace: str = None, - **kwargs) -> BatchTextEmbeddingResponse: + def async_call( # type: ignore[override] + cls, + model: str, + url: str, + api_key: str = None, + workspace: str = None, + **kwargs, + ) -> BatchTextEmbeddingResponse: """Create a async text embedding task, and return task information. Args: @@ -92,24 +101,28 @@ def async_call(cls, task id in the response. """ if url is None or not url: - raise InputRequired('url is required!') - input = {'url': url} + raise InputRequired("url is required!") + input = {"url": url} # pylint: disable=redefined-builtin task_group, _ = _get_task_group_and_task(__name__) - response = super().async_call(model=model, - task_group=task_group, - task=BatchTextEmbedding.task, - function=BatchTextEmbedding.function, - api_key=api_key, - input=input, - workspace=workspace, - **kwargs) + response = super().async_call( + model=model, + task_group=task_group, + task=BatchTextEmbedding.task, + function=BatchTextEmbedding.function, + api_key=api_key, + input=input, + workspace=workspace, + **kwargs, + ) return BatchTextEmbeddingResponse.from_api_response(response) @classmethod - def fetch(cls, - task: Union[str, BatchTextEmbeddingResponse], - api_key: str = None, - workspace: str = None) -> BatchTextEmbeddingResponse: + def fetch( # type: ignore[override] + cls, + task: Union[str, BatchTextEmbeddingResponse], + api_key: str = None, + workspace: str = None, + ) -> BatchTextEmbeddingResponse: """Fetch async text embedding task status or result. 
Args: @@ -125,11 +138,13 @@ def fetch(cls, return BatchTextEmbeddingResponse.from_api_response(response) @classmethod - def wait(cls, - task: Union[str, BatchTextEmbeddingResponse], - api_key: str = None, - workspace: str = None) -> BatchTextEmbeddingResponse: - """Wait for async text embedding task to complete, and return the result. + def wait( # type: ignore[override] + cls, + task: Union[str, BatchTextEmbeddingResponse], + api_key: str = None, + workspace: str = None, + ) -> BatchTextEmbeddingResponse: + """Wait for async text embedding task to complete, and return the result. # noqa: E501 Args: task (Union[str, AsyncTextEmbeddingResponse]): The task_id or @@ -144,10 +159,12 @@ def wait(cls, return BatchTextEmbeddingResponse.from_api_response(response) @classmethod - def cancel(cls, - task: Union[str, BatchTextEmbeddingResponse], - api_key: str = None, - workspace: str = None) -> DashScopeAPIResponse: + def cancel( # type: ignore[override] + cls, + task: Union[str, BatchTextEmbeddingResponse], + api_key: str = None, + workspace: str = None, + ) -> DashScopeAPIResponse: """Cancel async text embedding task. Only tasks whose status is PENDING can be canceled. @@ -163,18 +180,20 @@ def cancel(cls, return super().cancel(task, api_key, workspace=workspace) @classmethod - def list(cls, - start_time: str = None, - end_time: str = None, - model_name: str = None, - api_key_id: str = None, - region: str = None, - status: str = None, - page_no: int = 1, - page_size: int = 10, - api_key: str = None, - workspace: str = None, - **kwargs) -> DashScopeAPIResponse: + def list( + cls, + start_time: str = None, + end_time: str = None, + model_name: str = None, + api_key_id: str = None, + region: str = None, + status: str = None, + page_no: int = 1, + page_size: int = 10, + api_key: str = None, + workspace: str = None, + **kwargs, + ) -> DashScopeAPIResponse: """List async tasks. Args: @@ -195,14 +214,16 @@ def list(cls, Returns: DashScopeAPIResponse: The response data. 
""" - return super().list(start_time=start_time, - end_time=end_time, - model_name=model_name, - api_key_id=api_key_id, - region=region, - status=status, - page_no=page_no, - page_size=page_size, - api_key=api_key, - workspace=workspace, - **kwargs) + return super().list( + start_time=start_time, + end_time=end_time, + model_name=model_name, + api_key_id=api_key_id, + region=region, + status=status, + page_no=page_no, + page_size=page_size, + api_key=api_key, + workspace=workspace, + **kwargs, + ) diff --git a/dashscope/embeddings/batch_text_embedding_response.py b/dashscope/embeddings/batch_text_embedding_response.py index 559abf4..405ee10 100644 --- a/dashscope/embeddings/batch_text_embedding_response.py +++ b/dashscope/embeddings/batch_text_embedding_response.py @@ -1,11 +1,14 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. from http import HTTPStatus from attr import dataclass -from dashscope.api_entities.dashscope_response import (DashScopeAPIResponse, - DictMixin) +from dashscope.api_entities.dashscope_response import ( + DashScopeAPIResponse, + DictMixin, +) @dataclass(init=False) @@ -14,23 +17,27 @@ class BatchTextEmbeddingOutput(DictMixin): task_status: str url: str - def __init__(self, - task_id: str, - task_status: str, - url: str = None, - **kwargs): - super().__init__(self, - task_id=task_id, - task_status=task_status, - url=url, - **kwargs) + def __init__( + self, + task_id: str, + task_status: str, + url: str = None, + **kwargs, + ): + super().__init__( + self, + task_id=task_id, + task_status=task_status, + url=url, + **kwargs, + ) @dataclass(init=False) class BatchTextEmbeddingUsage(DictMixin): total_tokens: int - def __init__(self, total_tokens: int=None, **kwargs): + def __init__(self, total_tokens: int = None, **kwargs): super().__init__(total_tokens=total_tokens, **kwargs) @@ -55,11 +62,13 @@ def from_api_response(api_response: DashScopeAPIResponse): code=api_response.code, message=api_response.message, 
output=output, - usage=usage) + usage=usage, + ) else: return BatchTextEmbeddingResponse( status_code=api_response.status_code, request_id=api_response.request_id, code=api_response.code, - message=api_response.message) + message=api_response.message, + ) diff --git a/dashscope/embeddings/multimodal_embedding.py b/dashscope/embeddings/multimodal_embedding.py index 8d9b414..1247fce 100644 --- a/dashscope/embeddings/multimodal_embedding.py +++ b/dashscope/embeddings/multimodal_embedding.py @@ -1,10 +1,13 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. from dataclasses import dataclass from typing import List -from dashscope.api_entities.dashscope_response import (DashScopeAPIResponse, - DictMixin) +from dashscope.api_entities.dashscope_response import ( + DashScopeAPIResponse, + DictMixin, +) from dashscope.client.base_api import BaseApi, BaseAioApi from dashscope.common.error import InputRequired, ModelRequired from dashscope.common.utils import _get_task_group_and_task @@ -47,24 +50,27 @@ def __init__(self, audio: str, factor: float, **kwargs): class MultiModalEmbedding(BaseApi): - task = 'multimodal-embedding' + task = "multimodal-embedding" class Models: - multimodal_embedding_one_peace_v1 = 'multimodal-embedding-one-peace-v1' + multimodal_embedding_one_peace_v1 = "multimodal-embedding-one-peace-v1" @classmethod - def call(cls, - model: str, - input: List[MultiModalEmbeddingItemBase], - api_key: str = None, - workspace: str = None, - **kwargs) -> DashScopeAPIResponse: + def call( # type: ignore[override] + cls, + model: str, + # pylint: disable=redefined-builtin + input: List[MultiModalEmbeddingItemBase], + api_key: str = None, + workspace: str = None, + **kwargs, + ) -> DashScopeAPIResponse: """Get embedding multimodal contents.. Args: model (str): The embedding model name. - input (List[MultiModalEmbeddingElement]): The embedding elements, - every element include data, modal, factor field. 
+ input (List[MultiModalEmbeddingElement]): The embedding + elements, every element include data, modal, factor field. workspace (str): The dashscope workspace id. **kwargs: auto_truncation(bool, `optional`): Automatically truncate @@ -75,30 +81,40 @@ def call(cls, DashScopeAPIResponse: The embedding result. """ if input is None or not input: - raise InputRequired('prompt is required!') + raise InputRequired("prompt is required!") if model is None or not model: - raise ModelRequired('Model is required!') + raise ModelRequired("Model is required!") embedding_input = {} - has_upload = cls._preprocess_message_inputs(model, input, api_key) + has_upload = cls._preprocess_message_inputs( + model, + input, # type: ignore[arg-type] + api_key, + ) # noqa: E501 if has_upload: - headers = kwargs.pop('headers', {}) - headers['X-DashScope-OssResourceResolve'] = 'enable' - kwargs['headers'] = headers - embedding_input['contents'] = input - kwargs.pop('stream', False) # not support streaming output. + headers = kwargs.pop("headers", {}) + headers["X-DashScope-OssResourceResolve"] = "enable" + kwargs["headers"] = headers + embedding_input["contents"] = input + kwargs.pop("stream", False) # not support streaming output. 
task_group, function = _get_task_group_and_task(__name__) - return super().call(model=model, - input=embedding_input, - task_group=task_group, - task=MultiModalEmbedding.task, - function=function, - api_key=api_key, - workspace=workspace, - **kwargs) + return super().call( + model=model, + input=embedding_input, + task_group=task_group, + task=MultiModalEmbedding.task, + function=function, + api_key=api_key, + workspace=workspace, + **kwargs, + ) @classmethod - def _preprocess_message_inputs(cls, model: str, input: List[dict], - api_key: str): + def _preprocess_message_inputs( + cls, + model: str, + input_data: List[dict], + api_key: str, + ): """preprocess following inputs input = [{'factor': 1, 'text': 'hello'}, {'factor': 2, 'audio': ''}, @@ -106,34 +122,41 @@ def _preprocess_message_inputs(cls, model: str, input: List[dict], """ has_upload = False upload_certificate = None - for elem in input: + for elem in input_data: if not isinstance(elem, (int, float, bool, str, bytes, bytearray)): is_upload, upload_certificate = preprocess_message_element( - model, elem, api_key, upload_certificate) + model, + elem, + api_key, + upload_certificate, # type: ignore[arg-type] + ) if is_upload and not has_upload: has_upload = True return has_upload class AioMultiModalEmbedding(BaseAioApi): - task = 'multimodal-embedding' + task = "multimodal-embedding" class Models: - multimodal_embedding_one_peace_v1 = 'multimodal-embedding-one-peace-v1' + multimodal_embedding_one_peace_v1 = "multimodal-embedding-one-peace-v1" @classmethod - async def call(cls, - model: str, - input: List[MultiModalEmbeddingItemBase], - api_key: str = None, - workspace: str = None, - **kwargs) -> DashScopeAPIResponse: + async def call( # type: ignore[override] + cls, + model: str, + # pylint: disable=redefined-builtin + input: List[MultiModalEmbeddingItemBase], + api_key: str = None, + workspace: str = None, + **kwargs, + ) -> DashScopeAPIResponse: """Get embedding multimodal contents.. 
Args: model (str): The embedding model name. - input (List[MultiModalEmbeddingElement]): The embedding elements, - every element include data, modal, factor field. + input (List[MultiModalEmbeddingElement]): The embedding + elements, every element include data, modal, factor field. workspace (str): The dashscope workspace id. **kwargs: auto_truncation(bool, `optional`): Automatically truncate @@ -144,17 +167,21 @@ async def call(cls, DashScopeAPIResponse: The embedding result. """ if input is None or not input: - raise InputRequired('prompt is required!') + raise InputRequired("prompt is required!") if model is None or not model: - raise ModelRequired('Model is required!') + raise ModelRequired("Model is required!") embedding_input = {} - has_upload = cls._preprocess_message_inputs(model, input, api_key) + has_upload = cls._preprocess_message_inputs( + model, + input, # type: ignore[arg-type] + api_key, + ) # noqa: E501 if has_upload: - headers = kwargs.pop('headers', {}) - headers['X-DashScope-OssResourceResolve'] = 'enable' - kwargs['headers'] = headers - embedding_input['contents'] = input - kwargs.pop('stream', False) # not support streaming output. + headers = kwargs.pop("headers", {}) + headers["X-DashScope-OssResourceResolve"] = "enable" + kwargs["headers"] = headers + embedding_input["contents"] = input + kwargs.pop("stream", False) # not support streaming output. 
task_group, function = _get_task_group_and_task(__name__) response = await super().call( model=model, @@ -164,12 +191,17 @@ async def call(cls, function=function, api_key=api_key, workspace=workspace, - **kwargs) + **kwargs, + ) return response @classmethod - def _preprocess_message_inputs(cls, model: str, input: List[dict], - api_key: str): + def _preprocess_message_inputs( + cls, + model: str, + input_data: List[dict], + api_key: str, + ): """preprocess following inputs input = [{'factor': 1, 'text': 'hello'}, {'factor': 2, 'audio': ''}, @@ -177,10 +209,14 @@ def _preprocess_message_inputs(cls, model: str, input: List[dict], """ has_upload = False upload_certificate = None - for elem in input: + for elem in input_data: if not isinstance(elem, (int, float, bool, str, bytes, bytearray)): is_upload, upload_certificate = preprocess_message_element( - model, elem, api_key, upload_certificate) + model, + elem, + api_key, + upload_certificate, # type: ignore[arg-type] + ) if is_upload and not has_upload: has_upload = True return has_upload diff --git a/dashscope/embeddings/text_embedding.py b/dashscope/embeddings/text_embedding.py index 3526d37..973e303 100644 --- a/dashscope/embeddings/text_embedding.py +++ b/dashscope/embeddings/text_embedding.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
from typing import List, Union @@ -9,21 +10,23 @@ class TextEmbedding(BaseApi): - task = 'text-embedding' + task = "text-embedding" class Models: - text_embedding_v1 = 'text-embedding-v1' - text_embedding_v2 = 'text-embedding-v2' - text_embedding_v3 = 'text-embedding-v3' - text_embedding_v4 = 'text-embedding-v4' + text_embedding_v1 = "text-embedding-v1" + text_embedding_v2 = "text-embedding-v2" + text_embedding_v3 = "text-embedding-v3" + text_embedding_v4 = "text-embedding-v4" @classmethod - def call(cls, - model: str, - input: Union[str, List[str]], - workspace: str = None, - api_key: str = None, - **kwargs) -> DashScopeAPIResponse: + def call( # type: ignore[override] + cls, + model: str, + input: Union[str, List[str]], # pylint: disable=redefined-builtin + workspace: str = None, + api_key: str = None, + **kwargs, + ) -> DashScopeAPIResponse: """Get embedding of text input. Args: @@ -44,13 +47,15 @@ def call(cls, embedding_input[TEXT_EMBEDDING_INPUT_KEY] = [input] else: embedding_input[TEXT_EMBEDDING_INPUT_KEY] = input - kwargs.pop('stream', False) # not support streaming output. + kwargs.pop("stream", False) # not support streaming output. task_group, function = _get_task_group_and_task(__name__) - return super().call(model=model, - input=embedding_input, - task_group=task_group, - task=TextEmbedding.task, - function=function, - api_key=api_key, - workspace=workspace, - **kwargs) + return super().call( + model=model, + input=embedding_input, + task_group=task_group, + task=TextEmbedding.task, + function=function, + api_key=api_key, + workspace=workspace, + **kwargs, + ) diff --git a/dashscope/files.py b/dashscope/files.py index a27f900..70f0168 100644 --- a/dashscope/files.py +++ b/dashscope/files.py @@ -1,26 +1,33 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
import os from dashscope.api_entities.dashscope_response import DashScopeAPIResponse -from dashscope.client.base_api import (DeleteMixin, FileUploadMixin, GetMixin, - ListMixin) +from dashscope.client.base_api import ( + DeleteMixin, + FileUploadMixin, + GetMixin, + ListMixin, +) from dashscope.common.constants import FilePurpose from dashscope.common.error import InvalidFileFormat from dashscope.common.utils import is_validate_fine_tune_file class Files(FileUploadMixin, ListMixin, DeleteMixin, GetMixin): - SUB_PATH = 'files' + SUB_PATH = "files" @classmethod - def upload(cls, - file_path: str, - purpose: str = FilePurpose.fine_tune, - description: str = None, - api_key: str = None, - workspace: str = None, - **kwargs) -> DashScopeAPIResponse: + def upload( # pylint: disable=arguments-renamed + cls, + file_path: str, # type: ignore[override] + purpose: str = FilePurpose.fine_tune, # type: ignore[override] + description: str = None, # type: ignore[override] + api_key: str = None, + workspace: str = None, + **kwargs, + ) -> DashScopeAPIResponse: """Upload file for model fine-tune or other tasks. 
Args: @@ -36,23 +43,37 @@ def upload(cls, if purpose == FilePurpose.fine_tune: if not is_validate_fine_tune_file(file_path): raise InvalidFileFormat( - 'The file %s is not in valid jsonl format' % file_path) - with open(file_path, 'rb') as f: - return super().upload(files=[('files', (os.path.basename(f.name), - f, None))], - descriptions=[description] - if description is not None else None, - api_key=api_key, - workspace=workspace, - **kwargs) + f"The file {file_path} is not in valid jsonl format", + ) + with open(file_path, "rb") as f: + return super().upload( # type: ignore[return-value] + files=[ + ( + "files", + ( + os.path.basename(f.name), + f, + None, + ), + ), + ], + descriptions=[description] + if description is not None + else None, + api_key=api_key, + workspace=workspace, + **kwargs, + ) @classmethod - def list(cls, - page=1, - page_size=10, - api_key: str = None, - workspace: str = None, - **kwargs) -> DashScopeAPIResponse: + def list( # type: ignore[override] + cls, + page=1, + page_size=10, + api_key: str = None, + workspace: str = None, + **kwargs, + ) -> DashScopeAPIResponse: """List uploaded files. Args: @@ -66,18 +87,22 @@ def list(cls, Returns: DashScopeAPIResponse: The fine-tune jobs in the result. """ - return super().list(page, - page_size, - api_key, - workspace=workspace, - **kwargs) + return super().list( # type: ignore[return-value] + page, + page_size, + api_key, + workspace=workspace, + **kwargs, + ) @classmethod - def get(cls, - file_id: str, - api_key: str = None, - workspace: str = None, - **kwargs) -> DashScopeAPIResponse: + def get( # type: ignore[override] + cls, + file_id: str, + api_key: str = None, + workspace: str = None, + **kwargs, + ) -> DashScopeAPIResponse: """Get the file info. 
Args: @@ -88,14 +113,17 @@ def get(cls, Returns: DashScopeAPIResponse: The job info """ - return super().get(file_id, api_key, workspace=workspace, **kwargs) + # type: ignore + return super().get(file_id, api_key, workspace=workspace, **kwargs) # type: ignore[return-value] # pylint: disable=line-too-long # noqa: E501 @classmethod - def delete(cls, - file_id: str, - api_key: str = None, - workspace: str = None, - **kwargs) -> DashScopeAPIResponse: + def delete( # type: ignore[override] + cls, + file_id: str, + api_key: str = None, + workspace: str = None, + **kwargs, + ) -> DashScopeAPIResponse: """Delete uploaded file. Args: @@ -105,5 +133,5 @@ def delete(cls, Returns: DashScopeAPIResponse: Delete result. - """ - return super().delete(file_id, api_key, workspace=workspace, **kwargs) + """ # type: ignore + return super().delete(file_id, api_key, workspace=workspace, **kwargs) # type: ignore[return-value] # pylint: disable=line-too-long # noqa: E501 diff --git a/dashscope/io/input_output.py b/dashscope/io/input_output.py index a1a5dc1..ea6caf2 100644 --- a/dashscope/io/input_output.py +++ b/dashscope/io/input_output.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
import base64 @@ -8,28 +9,35 @@ class InputResolver: - def __init__(self, - input_instance, - is_encode_binary: bool = True, - custom_type_resolver: dict = {}): + # pylint: disable=dangerous-default-value + def __init__( + self, + input_instance, + is_encode_binary: bool = True, + custom_type_resolver: dict = {}, + ): self._instance = input_instance self._is_encode_binary = is_encode_binary self._custom_type_resolver = custom_type_resolver def __next__(self): while True: - return resolve_input(self._instance, self._is_encode_binary, - self._custom_type_resolver) + return resolve_input( + self._instance, + self._is_encode_binary, + self._custom_type_resolver, + ) def __iter__(self): return self -def _is_binary_file(bytes): +def _is_binary_file(bytes): # pylint: disable=redefined-builtin # solution from: https://stackoverflow.com/questions/898669/ # how-can-i-detect-if-a-file-is-binary-non-text-in-python/7392391#7392391 - text_chars = bytearray({7, 8, 9, 10, 12, 13, 27} - | set(range(0x20, 0x100)) - {0x7f}) + text_chars = bytearray( + {7, 8, 9, 10, 12, 13, 27} | set(range(0x20, 0x100)) - {0x7F}, + ) if len(bytes) > 1024: temp_bytes = bytes[:1024] else: @@ -37,7 +45,12 @@ def _is_binary_file(bytes): return bool(temp_bytes.translate(None, text_chars)) -def resolve_input(input, is_encode_binary, custom_type_resolver: dict = {}): +# pylint: disable=too-many-branches,dangerous-default-value,too-many-return-statements # noqa: E501 +def resolve_input( + input, + is_encode_binary, + custom_type_resolver: dict = {}, +): # pylint: disable=redefined-builtin # noqa: E501 """Resolve input data, if same field is file, generator, we can get data. 
Args: @@ -61,14 +74,18 @@ def ndarray_tolist(ndar): if isinstance(input, dict): out_input = {} for key, value in input.items(): - out_input[key] = resolve_input(value, is_encode_binary, - custom_type_resolver) + out_input[key] = resolve_input( + value, + is_encode_binary, + custom_type_resolver, + ) return out_input elif isinstance(input, (list, tuple, set)): out_input = [] for item in input: out_input.append( - resolve_input(item, is_encode_binary, custom_type_resolver)) + resolve_input(item, is_encode_binary, custom_type_resolver), + ) return out_input elif isinstance(input, str): return input @@ -78,7 +95,7 @@ def ndarray_tolist(ndar): return input elif isinstance(input, (bytearray, bytes, memoryview)): if is_encode_binary: - return base64.b64encode(input).decode('ascii') + return base64.b64encode(input).decode("ascii") else: return input elif isinstance(input, io.IOBase): @@ -93,11 +110,11 @@ def ndarray_tolist(ndar): is_binary_file = _is_binary_file(content) if is_binary_file: if is_encode_binary: - return base64.b64encode(content).decode('ascii') + return base64.b64encode(content).decode("ascii") else: return content else: # split line by line. - return content.decode('utf-8').splitlines() + return content.decode("utf-8").splitlines() else: if is_encode_binary: return content.splitlines() @@ -110,5 +127,6 @@ def ndarray_tolist(ndar): elif type(input) in custom_type_resolver: return custom_type_resolver[type(input)](input) else: - raise UnsupportedDataType('Unsupported atom data type: %s' % - type(input)) + raise UnsupportedDataType( + f"Unsupported atom data type: {type(input)}", + ) diff --git a/dashscope/model.py b/dashscope/model.py index 5ac9aed..6aea734 100644 --- a/dashscope/model.py +++ b/dashscope/model.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
from dashscope.api_entities.dashscope_response import DashScopeAPIResponse @@ -5,14 +6,16 @@ class Model(ListMixin, GetMixin): - SUB_PATH = 'models' + SUB_PATH = "models" @classmethod - def get(cls, - name: str, - api_key: str = None, - workspace: str = None, - **kwargs) -> DashScopeAPIResponse: + def get( # type: ignore[override] + cls, + name: str, + api_key: str = None, + workspace: str = None, + **kwargs, + ) -> DashScopeAPIResponse: """Get the model information. Args: @@ -23,15 +26,23 @@ def get(cls, Returns: DashScopeAPIResponse: The model information. """ - return super().get(name, api_key, workspace=workspace, **kwargs) + # type: ignore + return super().get( # type: ignore[return-value] + name, + api_key, + workspace=workspace, + **kwargs, + ) # noqa: E501 @classmethod - def list(cls, - page=1, - page_size=10, - api_key: str = None, - workspace: str = None, - **kwargs) -> DashScopeAPIResponse: + def list( # type: ignore[override] + cls, + page=1, + page_size=10, + api_key: str = None, + workspace: str = None, + **kwargs, + ) -> DashScopeAPIResponse: """List models. Args: @@ -42,8 +53,10 @@ def list(cls, Returns: DashScopeAPIResponse: The models. """ - return super().list(api_key, - page, - page_size, - workspace=workspace, - **kwargs) + return super().list( # type: ignore[return-value] + api_key, + page, + page_size, + workspace=workspace, + **kwargs, + ) diff --git a/dashscope/models.py b/dashscope/models.py index 9f77564..764e088 100644 --- a/dashscope/models.py +++ b/dashscope/models.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
from dashscope.api_entities.dashscope_response import DashScopeAPIResponse @@ -5,13 +6,15 @@ class Models(ListMixin, GetMixin): - SUB_PATH = 'models' + SUB_PATH = "models" @classmethod - def get(cls, - name: str, - api_key: str = None, - **kwargs) -> DashScopeAPIResponse: + def get( # type: ignore[override] + cls, + name: str, + api_key: str = None, + **kwargs, + ) -> DashScopeAPIResponse: """Get the model information. Args: @@ -22,14 +25,16 @@ def get(cls, Returns: DashScopeAPIResponse: The model information. """ - return super().get(name, api_key, **kwargs) + return super().get(name, api_key, **kwargs) # type: ignore @classmethod - def list(cls, - page=1, - page_size=10, - api_key: str = None, - **kwargs) -> DashScopeAPIResponse: + def list( # type: ignore[override] + cls, + page=1, + page_size=10, + api_key: str = None, + **kwargs, + ) -> DashScopeAPIResponse: """List models. Args: @@ -40,4 +45,5 @@ def list(cls, Returns: DashScopeAPIResponse: The models. """ - return super().list(page, page_size, api_key=api_key, **kwargs) + # type: ignore + return super().list(page, page_size, api_key=api_key, **kwargs) # type: ignore[return-value] # pylint: disable=line-too-long # noqa: E501 diff --git a/dashscope/multimodal/__init__.py b/dashscope/multimodal/__init__.py index 5787fc6..540bb57 100644 --- a/dashscope/multimodal/__init__.py +++ b/dashscope/multimodal/__init__.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
from .tingwu import tingwu @@ -6,15 +7,89 @@ from .multimodal_dialog import MultiModalDialog, MultiModalCallback from .dialog_state import DialogState -from .multimodal_constants import * -from .multimodal_request_params import * +from .multimodal_constants import ( + RequestToRespondType, + RESPONSE_NAME_TASK_STARTED, + RESPONSE_NAME_RESULT_GENERATED, + RESPONSE_NAME_TASK_FINISHED, + RESPONSE_NAME_TASK_FAILED, + RESPONSE_NAME_STARTED, + RESPONSE_NAME_STOPPED, + RESPONSE_NAME_STATE_CHANGED, + RESPONSE_NAME_REQUEST_ACCEPTED, + RESPONSE_NAME_SPEECH_STARTED, + RESPONSE_NAME_SPEECH_ENDED, + RESPONSE_NAME_RESPONDING_STARTED, + RESPONSE_NAME_RESPONDING_ENDED, + RESPONSE_NAME_SPEECH_CONTENT, + RESPONSE_NAME_RESPONDING_CONTENT, + RESPONSE_NAME_ERROR, + RESPONSE_NAME_HEART_BEAT, +) +from .multimodal_request_params import ( + DashHeader, + DashPayloadParameters, + DashPayloadInput, + DashPayload, + RequestBodyInput, + AsrPostProcessing, + ReplaceWord, + Upstream, + Downstream, + DialogAttributes, + Locations, + Network, + Device, + ClientInfo, + BizParams, + RequestParameters, + RequestToRespondParameters, + RequestToRespondBodyInput, + get_random_uuid, +) __all__ = [ - 'tingwu', - 'TingWu', - 'TingWuRealtime', - 'TingWuRealtimeCallback', - 'MultiModalDialog', - 'MultiModalCallback', - 'DialogState' -] \ No newline at end of file + "tingwu", + "TingWu", + "TingWuRealtime", + "TingWuRealtimeCallback", + "MultiModalDialog", + "MultiModalCallback", + "DialogState", + "RequestToRespondType", + "RESPONSE_NAME_TASK_STARTED", + "RESPONSE_NAME_RESULT_GENERATED", + "RESPONSE_NAME_TASK_FINISHED", + "RESPONSE_NAME_TASK_FAILED", + "RESPONSE_NAME_STARTED", + "RESPONSE_NAME_STOPPED", + "RESPONSE_NAME_STATE_CHANGED", + "RESPONSE_NAME_REQUEST_ACCEPTED", + "RESPONSE_NAME_SPEECH_STARTED", + "RESPONSE_NAME_SPEECH_ENDED", + "RESPONSE_NAME_RESPONDING_STARTED", + "RESPONSE_NAME_RESPONDING_ENDED", + "RESPONSE_NAME_SPEECH_CONTENT", + "RESPONSE_NAME_RESPONDING_CONTENT", + "RESPONSE_NAME_ERROR", + 
"RESPONSE_NAME_HEART_BEAT", + "DashHeader", + "DashPayloadParameters", + "DashPayloadInput", + "DashPayload", + "RequestBodyInput", + "AsrPostProcessing", + "ReplaceWord", + "Upstream", + "Downstream", + "DialogAttributes", + "Locations", + "Network", + "Device", + "ClientInfo", + "BizParams", + "RequestParameters", + "RequestToRespondParameters", + "RequestToRespondBodyInput", + "get_random_uuid", +] diff --git a/dashscope/multimodal/dialog_state.py b/dashscope/multimodal/dialog_state.py index 6c57d08..1bead25 100644 --- a/dashscope/multimodal/dialog_state.py +++ b/dashscope/multimodal/dialog_state.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # dialog_state.py from enum import Enum @@ -13,10 +14,11 @@ class DialogState(Enum): THINKING (str): 表示机器人正在思考。 RESPONDING (str): 表示机器人正在生成或回复中。 """ - IDLE = 'Idle' - LISTENING = 'Listening' - THINKING = 'Thinking' - RESPONDING = 'Responding' + + IDLE = "Idle" + LISTENING = "Listening" + THINKING = "Thinking" + RESPONDING = "Responding" class StateMachine: diff --git a/dashscope/multimodal/multimodal_constants.py b/dashscope/multimodal/multimodal_constants.py index 46eabc6..12a2cb0 100644 --- a/dashscope/multimodal/multimodal_constants.py +++ b/dashscope/multimodal/multimodal_constants.py @@ -1,11 +1,13 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
# -*- coding: utf-8 -*- # multimodal conversation request directive + class RequestToRespondType: - TRANSCRIPT = 'transcript' - PROMPT = 'prompt' + TRANSCRIPT = "transcript" + PROMPT = "prompt" # multimodal conversation response directive @@ -20,9 +22,11 @@ class RequestToRespondType: RESPONSE_NAME_REQUEST_ACCEPTED = "RequestAccepted" RESPONSE_NAME_SPEECH_STARTED = "SpeechStarted" RESPONSE_NAME_SPEECH_ENDED = "SpeechEnded" # 服务端检测到asr语音尾点时下发此事件,可选事件 -RESPONSE_NAME_RESPONDING_STARTED = "RespondingStarted" # AI语音应答开始,sdk要准备接收服务端下发的语音数据 +RESPONSE_NAME_RESPONDING_STARTED = ( + "RespondingStarted" # AI语音应答开始,sdk要准备接收服务端下发的语音数据 +) RESPONSE_NAME_RESPONDING_ENDED = "RespondingEnded" # AI语音应答结束 RESPONSE_NAME_SPEECH_CONTENT = "SpeechContent" # 用户语音识别出的文本,流式全量输出 RESPONSE_NAME_RESPONDING_CONTENT = "RespondingContent" # 统对外输出的文本,流式全量输出 RESPONSE_NAME_ERROR = "Error" # 服务端对话中报错 -RESPONSE_NAME_HEART_BEAT = "HeartBeat" # 心跳消息 \ No newline at end of file +RESPONSE_NAME_HEART_BEAT = "HeartBeat" # 心跳消息 diff --git a/dashscope/multimodal/multimodal_dialog.py b/dashscope/multimodal/multimodal_dialog.py index 99dc269..10a40f3 100644 --- a/dashscope/multimodal/multimodal_dialog.py +++ b/dashscope/multimodal/multimodal_dialog.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- import json import platform import time @@ -10,9 +11,29 @@ from dashscope.common.logging import logger from dashscope.common.error import InputRequired from dashscope.multimodal import dialog_state -from dashscope.multimodal.multimodal_constants import * -from dashscope.multimodal.multimodal_request_params import RequestParameters, get_random_uuid, DashHeader, \ - RequestBodyInput, DashPayload, RequestToRespondParameters, RequestToRespondBodyInput +from dashscope.multimodal.multimodal_constants import ( + RESPONSE_NAME_STARTED, + RESPONSE_NAME_STOPPED, + RESPONSE_NAME_STATE_CHANGED, + RESPONSE_NAME_REQUEST_ACCEPTED, + RESPONSE_NAME_SPEECH_STARTED, + RESPONSE_NAME_SPEECH_ENDED, + RESPONSE_NAME_RESPONDING_STARTED, + 
RESPONSE_NAME_RESPONDING_ENDED, + RESPONSE_NAME_SPEECH_CONTENT, + RESPONSE_NAME_RESPONDING_CONTENT, + RESPONSE_NAME_ERROR, + RESPONSE_NAME_HEART_BEAT, +) +from dashscope.multimodal.multimodal_request_params import ( + RequestParameters, + get_random_uuid, + DashHeader, + RequestBodyInput, + DashPayload, + RequestToRespondParameters, + RequestToRespondBodyInput, +) from dashscope.protocol.websocket import ActionType @@ -27,21 +48,18 @@ def on_started(self, dialog_id: str) -> None: :param dialog_id: 回调对话ID """ - pass def on_stopped(self) -> None: """ 通知对话停止 """ - pass - def on_state_changed(self, state: 'dialog_state.DialogState') -> None: + def on_state_changed(self, state: "dialog_state.DialogState") -> None: """ 对话状态改变 :param state: 新的对话状态 """ - pass def on_speech_audio_data(self, data: bytes) -> None: """ @@ -49,7 +67,6 @@ def on_speech_audio_data(self, data: bytes) -> None: :param data: 音频数据 """ - pass def on_error(self, error) -> None: """ @@ -57,37 +74,31 @@ def on_error(self, error) -> None: :param error: 错误信息 """ - pass def on_connected(self) -> None: """ 成功连接到服务器后调用此方法。 """ - pass def on_responding_started(self): """ 回复开始回调 """ - pass def on_responding_ended(self, payload): """ 回复结束 """ - pass def on_speech_started(self): """ 检测到语音输入结束 """ - pass def on_speech_ended(self): """ 检测到语音输入结束 """ - pass def on_speech_content(self, payload): """ @@ -95,7 +106,6 @@ def on_speech_content(self, payload): :param payload: text """ - pass def on_responding_content(self, payload): """ @@ -103,13 +113,11 @@ def on_responding_content(self, payload): :param payload: text """ - pass def on_request_accepted(self): """ 打断请求被接受。 """ - pass def on_close(self, close_status_code, close_msg): """ @@ -118,7 +126,6 @@ def on_close(self, close_status_code, close_msg): :param close_status_code: 关闭状态码 :param close_msg: 关闭消息 """ - pass class MultiModalDialog: @@ -126,31 +133,32 @@ class MultiModalDialog: 用于管理WebSocket连接以进行语音聊天的服务类。 """ - def __init__(self, - app_id: str, - 
request_params: RequestParameters, - multimodal_callback: MultiModalCallback, - workspace_id: str = None, - url: str = None, - api_key: str = None, - dialog_id: str = None, - model: str = None - ): - """ - 创建一个语音对话会话。 - - 此方法用于初始化一个新的voice_chat会话,设置必要的参数以准备开始与模型的交互。 - :param workspace_id: 客户的workspace_id 主工作空间id,非必填字段 - :param app_id: 客户在管控台创建的应用id,可以根据值规律确定使用哪个对话系统 - :param request_params: 请求参数集合 - :param url: (str) API的URL地址。 - :param multimodal_callback: (MultimodalCallback) 回调对象,用于处理来自服务器的消息。 - :param api_key: (str) 应用程序接入的唯一key - :param dialog_id:对话id,如果传入表示承接上下文继续聊 - :param model: 模型 + def __init__( + self, + app_id: str, + request_params: RequestParameters, + multimodal_callback: MultiModalCallback, + workspace_id: str = None, + url: str = None, + api_key: str = None, + dialog_id: str = None, + model: str = None, + ): + """ + 创建一个语音对话会话。 + + 此方法用于初始化一个新的voice_chat会话,设置必要的参数以准备开始与模型的交互。 + :param workspace_id: 客户的workspace_id 主工作空间id,非必填字段 + :param app_id: 客户在管控台创建的应用id,可以根据值规律确定使用哪个对话系统 + :param request_params: 请求参数集合 + :param url: (str) API的URL地址。 + :param multimodal_callback: (MultimodalCallback) 回调对象,用于处理来自服务器的消息。 + :param api_key: (str) 应用程序接入的唯一key + :param dialog_id:对话id,如果传入表示承接上下文继续聊 + :param model: 模型 """ if request_params is None: - raise InputRequired('request_params is required!') + raise InputRequired("request_params is required!") if url is None: url = dashscope.base_websocket_api_url if api_key is None: @@ -169,24 +177,40 @@ def __init__(self, self.app_id = app_id self.dialog_id = dialog_id self.dialog_state = dialog_state.StateMachine() - self.response = _Response(self.dialog_state, self._callback, self.close) # 传递 self.close 作为回调 - - def _on_message(self, ws, message): + self.response = _Response( + self.dialog_state, + self._callback, + self.close, + ) # 传递 self.close 作为回调 + + def _on_message( # pylint: disable=unused-argument + self, + ws, + message, + ): logger.debug(f"<<<<<<< Received message: {message}") if isinstance(message, str): 
self.response.handle_text_response(message) elif isinstance(message, (bytes, bytearray)): self.response.handle_binary_response(message) - def _on_error(self, ws, error): + def _on_error(self, ws, error): # pylint: disable=unused-argument logger.error(f"Error: {error}") if self._callback: self._callback.on_error(error) - def _on_close(self, ws, close_status_code, close_msg): + def _on_close( # pylint: disable=unused-argument + self, + ws, + close_status_code, + close_msg, + ): try: logger.debug( - "WebSocket connection closed with status {} and message {}".format(close_status_code, close_msg)) + "WebSocket connection closed with status %s and message %s", # noqa: E501 + close_status_code, + close_msg, + ) if close_status_code is None: close_status_code = 1000 if close_msg is None: @@ -195,7 +219,7 @@ def _on_close(self, ws, close_status_code, close_msg): except Exception as e: logger.error(f"Error: {e}") - def _on_open(self, ws): + def _on_open(self, ws): # pylint: disable=unused-argument self._callback.on_connected() # def _on_pong(self, _): @@ -212,11 +236,18 @@ def start(self, dialog_id, enable_voice_detection=False, task_id=None): self._voice_detection = enable_voice_detection self._connect(self.api_key) logger.debug("connected with server.") - self._send_start_request(dialog_id, self.request_params, task_id=task_id) + self._send_start_request( + dialog_id, + self.request_params, + task_id=task_id, + ) def start_speech(self): """开始上传语音数据""" - _send_speech_json = self.request.generate_common_direction_request("SendSpeech", self.dialog_id) + _send_speech_json = self.request.generate_common_direction_request( + "SendSpeech", + self.dialog_id, + ) self._send_text_frame(_send_speech_json) def send_audio_data(self, speech_data: bytes): @@ -225,23 +256,34 @@ def send_audio_data(self, speech_data: bytes): def stop_speech(self): """停止上传语音数据""" - _send_speech_json = self.request.generate_common_direction_request("StopSpeech", self.dialog_id) + _send_speech_json = 
self.request.generate_common_direction_request( + "StopSpeech", + self.dialog_id, + ) self._send_text_frame(_send_speech_json) def interrupt(self): """请求服务端开始说话""" - _send_speech_json = self.request.generate_common_direction_request("RequestToSpeak", self.dialog_id) + _send_speech_json = self.request.generate_common_direction_request( + "RequestToSpeak", + self.dialog_id, + ) self._send_text_frame(_send_speech_json) - def request_to_respond(self, - request_type: str, - text: str, - parameters: RequestToRespondParameters = None): + def request_to_respond( + self, + request_type: str, + text: str, + parameters: RequestToRespondParameters = None, + ): """请求服务端直接文本合成语音""" - _send_speech_json = self.request.generate_request_to_response_json(direction_name="RequestToRespond", - dialog_id=self.dialog_id, - request_type=request_type, text=text, - parameters=parameters) + _send_speech_json = self.request.generate_request_to_response_json( + direction_name="RequestToRespond", + dialog_id=self.dialog_id, + request_type=request_type, + text=text, + parameters=parameters, + ) self._send_text_frame(_send_speech_json) @abstractmethod @@ -251,29 +293,45 @@ def request_to_respond_prompt(self, text): def local_responding_started(self): """本地tts播放开始""" - _send_speech_json = self.request.generate_common_direction_request("LocalRespondingStarted", self.dialog_id) + _send_speech_json = self.request.generate_common_direction_request( + "LocalRespondingStarted", + self.dialog_id, + ) self._send_text_frame(_send_speech_json) def local_responding_ended(self): """本地tts播放结束""" - _send_speech_json = self.request.generate_common_direction_request("LocalRespondingEnded", self.dialog_id) + _send_speech_json = self.request.generate_common_direction_request( + "LocalRespondingEnded", + self.dialog_id, + ) self._send_text_frame(_send_speech_json) def send_heart_beat(self): """发送心跳""" - _send_speech_json = self.request.generate_common_direction_request("HeartBeat", self.dialog_id) + _send_speech_json 
= self.request.generate_common_direction_request( + "HeartBeat", + self.dialog_id, + ) self._send_text_frame(_send_speech_json) def update_info(self, parameters: RequestToRespondParameters = None): """更新信息""" - _send_speech_json = self.request.generate_update_info_json(direction_name="UpdateInfo", dialog_id=self.dialog_id, parameters=parameters) + _send_speech_json = self.request.generate_update_info_json( + direction_name="UpdateInfo", + dialog_id=self.dialog_id, + parameters=parameters, + ) self._send_text_frame(_send_speech_json) def stop(self): if self.ws is None or not self.ws.sock or not self.ws.sock.connected: self._callback.on_close(1001, "websocket is not connected") return - _send_speech_json = self.request.generate_stop_request("Stop", self.dialog_id) + _send_speech_json = self.request.generate_stop_request( + "Stop", + self.dialog_id, + ) self._send_text_frame(_send_speech_json) def get_dialog_state(self) -> dialog_state.DialogState: @@ -283,9 +341,14 @@ def get_conversation_mode(self) -> str: """get mode of conversation: support tap2talk/push2talk/duplex""" return self.request_params.upstream.mode - """内部方法""" + """内部方法""" # pylint: disable=pointless-string-statement - def _send_start_request(self, dialog_id: str, request_params: RequestParameters, task_id: str = None): + def _send_start_request( + self, + dialog_id: str, + request_params: RequestParameters, + task_id: str = None, + ): """发送'Start'请求""" _start_json = self.request.generate_start_request( workspace_id=self.workspace_id, @@ -294,7 +357,7 @@ def _send_start_request(self, dialog_id: str, request_params: RequestParameters, app_id=self.app_id, request_params=request_params, model=self.model, - task_id=task_id + task_id=task_id, ) # send start request self._send_text_frame(_start_json) @@ -304,11 +367,14 @@ def _run_forever(self): def _connect(self, api_key: str): """初始化WebSocket连接并发送启动请求。""" - self.ws = websocket.WebSocketApp(self.url, header=self.request.get_websocket_header(api_key), - 
on_open=self._on_open, - on_message=self._on_message, - on_error=self._on_error, - on_close=self._on_close) + self.ws = websocket.WebSocketApp( + self.url, + header=self.request.get_websocket_header(api_key), + on_open=self._on_open, + on_message=self._on_message, + on_error=self._on_error, + on_close=self._on_close, + ) self.thread = threading.Thread(target=self._run_forever) self.ws.ping_interval = 3 self.thread.daemon = True @@ -325,11 +391,14 @@ def _wait_for_connection(self): """等待WebSocket连接建立""" timeout = 5 start_time = time.time() - while not (self.ws.sock and self.ws.sock.connected) and (time.time() - start_time) < timeout: + while ( + not (self.ws.sock and self.ws.sock.connected) + and (time.time() - start_time) < timeout + ): time.sleep(0.1) # 短暂休眠,避免密集轮询 def _send_text_frame(self, text: str): - logger.info('>>>>>> send text frame : %s' % text) + logger.info(">>>>>> send text frame : %s", text) self.ws.send(text, websocket.ABNF.OPCODE_TEXT) def __send_binary_frame(self, binary: bytes): @@ -369,28 +438,29 @@ def __init__(self): self.workspace_id = None def get_websocket_header(self, api_key): - ua = 'dashscope/%s; python/%s; platform/%s; processor/%s' % ( - '1.18.0', # dashscope version - platform.python_version(), - platform.platform(), - platform.processor(), + ua = ( + f"dashscope/1.18.0; python/{platform.python_version()}; " + f"platform/{platform.platform()}; " + f"processor/{platform.processor()}" ) self.ws_headers = { "User-Agent": ua, "Authorization": f"bearer {api_key}", - "Accept": "application/json" + "Accept": "application/json", } - logger.info('websocket header: {}'.format(self.ws_headers)) + logger.info("websocket header: %s", self.ws_headers) return self.ws_headers - def generate_start_request(self, direction_name: str, - dialog_id: str, - app_id: str, - request_params: RequestParameters, - model: str = None, - workspace_id: str = None, - task_id: str = None - ) -> str: + def generate_start_request( + self, + direction_name: str, + 
dialog_id: str, + app_id: str, + request_params: RequestParameters, + model: str = None, + workspace_id: str = None, + task_id: str = None, + ) -> str: """ 构建语音聊天服务的启动请求数据. :param app_id: 管控台应用id @@ -404,16 +474,26 @@ def generate_start_request(self, direction_name: str, """ self.task_id = task_id self._get_dash_request_header(ActionType.START) - self._get_dash_request_payload(direction_name, dialog_id, app_id, workspace_id=workspace_id, - request_params=request_params, model=model) + self._get_dash_request_payload( + direction_name, + dialog_id, + app_id, + workspace_id=workspace_id, + request_params=request_params, + model=model, + ) cmd = { "header": self.header, - "payload": self.payload + "payload": self.payload, } return json.dumps(cmd) - def generate_common_direction_request(self, direction_name: str, dialog_id: str) -> str: + def generate_common_direction_request( + self, + direction_name: str, + dialog_id: str, + ) -> str: """ 构建语音聊天服务的命令请求数据. :param direction_name: 命令. @@ -424,11 +504,15 @@ def generate_common_direction_request(self, direction_name: str, dialog_id: str) self._get_dash_request_payload(direction_name, dialog_id, self.app_id) cmd = { "header": self.header, - "payload": self.payload + "payload": self.payload, } return json.dumps(cmd) - def generate_stop_request(self, direction_name: str, dialog_id: str) -> str: + def generate_stop_request( + self, + direction_name: str, + dialog_id: str, + ) -> str: """ 构建语音聊天服务的启动请求数据. 
:param direction_name:指令名称 @@ -440,17 +524,23 @@ def generate_stop_request(self, direction_name: str, dialog_id: str) -> str: cmd = { "header": self.header, - "payload": self.payload + "payload": self.payload, } return json.dumps(cmd) - def generate_request_to_response_json(self, direction_name: str, dialog_id: str, request_type: str, text: str, - parameters: RequestToRespondParameters = None) -> str: + def generate_request_to_response_json( + self, + direction_name: str, + dialog_id: str, + request_type: str, + text: str, + parameters: RequestToRespondParameters = None, + ) -> str: """ 构建语音聊天服务的命令请求数据. :param direction_name: 命令. :param dialog_id: 对话ID. - :param request_type: 服务应该采取的交互类型,transcript 表示直接把文本转语音,prompt 表示把文本送大模型回答 + :param request_type: 服务应该采取的交互类型,transcript 表示直接把文本转语音,prompt 表示把文本送大模型回答 # noqa: E501 :param text: 文本. :param parameters: 命令请求body中的parameters :return: 命令请求字典. @@ -462,18 +552,28 @@ def generate_request_to_response_json(self, direction_name: str, dialog_id: str, directive=direction_name, dialog_id=dialog_id, type_=request_type, - text=text + text=text, ) - self._get_dash_request_payload(direction_name, dialog_id, self.app_id, request_params=parameters, - custom_input=custom_input) + self._get_dash_request_payload( + direction_name, + dialog_id, + self.app_id, + request_params=parameters, # type: ignore[arg-type] + custom_input=custom_input, + ) cmd = { "header": self.header, - "payload": self.payload + "payload": self.payload, } return json.dumps(cmd) - def generate_update_info_json(self, direction_name: str, dialog_id: str,parameters: RequestToRespondParameters = None) -> str: + def generate_update_info_json( + self, + direction_name: str, + dialog_id: str, + parameters: RequestToRespondParameters = None, + ) -> str: """ 构建语音聊天服务的命令请求数据. :param direction_name: 命令. 
@@ -488,26 +588,38 @@ def generate_update_info_json(self, direction_name: str, dialog_id: str,paramete dialog_id=dialog_id, ) - self._get_dash_request_payload(direction_name, dialog_id, self.app_id, request_params=parameters, - custom_input=custom_input) + self._get_dash_request_payload( + direction_name, + dialog_id, + self.app_id, + request_params=parameters, # type: ignore[arg-type] + custom_input=custom_input, + ) cmd = { "header": self.header, - "payload": self.payload + "payload": self.payload, } return json.dumps(cmd) def _get_dash_request_header(self, action: str): """ 构建多模对话请求的请求协议Header - :param action: ActionType 百炼协议action 支持:run-task, continue-task, finish-task + :param action: ActionType 百炼协议action 支持:run-task, continue-task, finish-task # noqa: E501 """ if self.task_id is None: self.task_id = get_random_uuid() self.header = DashHeader(action=action, task_id=self.task_id).to_dict() - def _get_dash_request_payload(self, direction_name: str, - dialog_id: str, app_id: str, workspace_id: str = None, - request_params: RequestParameters = None, custom_input=None, model: str = None): + def _get_dash_request_payload( + self, + direction_name: str, + dialog_id: str, + app_id: str, + workspace_id: str = None, + request_params: RequestParameters = None, + custom_input=None, + model: str = None, + ): """ 构建多模对话请求的请求协议payload :param direction_name: 对话协议内部的指令名称 @@ -518,46 +630,64 @@ def _get_dash_request_payload(self, direction_name: str, :param model: 模型 """ if custom_input is not None: - input = custom_input + input = custom_input # pylint: disable=redefined-builtin else: input = RequestBodyInput( workspace_id=workspace_id, app_id=app_id, directive=direction_name, - dialog_id=dialog_id + dialog_id=dialog_id, ) self.payload = DashPayload( model=model, input=input, - parameters=request_params + parameters=request_params, ).to_dict() class _Response: - def __init__(self, state: dialog_state.StateMachine, callback: MultiModalCallback, close_callback=None): + def 
__init__( + self, + state: dialog_state.StateMachine, + callback: MultiModalCallback, + close_callback=None, + ): super().__init__() self.dialog_id = None # 对话ID. self.dialog_state = state self._callback = callback self._close_callback = close_callback # 保存关闭回调函数 + # pylint: disable=inconsistent-return-statements def handle_text_response(self, response_json: str): """ 处理语音聊天服务的响应数据. :param response_json: 从服务接收到的原始JSON字符串响应。 """ - logger.info("<<<<<< server response: %s" % response_json) + logger.info("<<<<<< server response: %s", response_json) try: # 尝试将消息解析为JSON json_data = json.loads(response_json) - if "status_code" in json_data["header"] and json_data["header"]["status_code"] != 200: - logger.error("Server returned invalid message: %s" % response_json) + if ( + "status_code" in json_data["header"] + and json_data["header"]["status_code"] != 200 + ): + logger.error( + "Server returned invalid message: %s", + response_json, + ) if self._callback: self._callback.on_error(response_json) return - if "event" in json_data["header"] and json_data["header"]["event"] == "task-failed": - logger.error("Server returned invalid message: %s" % response_json) + if ( + "event" in json_data["header"] + and json_data["header"]["event"] == "task-failed" + ): + logger.error( + "Server returned invalid message: %s", + response_json, + ) if self._callback: self._callback.on_error(response_json) return None @@ -565,14 +695,21 @@ def handle_text_response(self, response_json: str): payload = json_data["payload"] if "output" in payload and payload["output"] is not None: response_event = payload["output"]["event"] - logger.info("Server response event: %s" % response_event) - self._handle_text_response_in_conversation(response_event=response_event, response_json=json_data) + logger.info("Server response event: %s", response_event) + self._handle_text_response_in_conversation( + response_event=response_event, + response_json=json_data, + ) del json_data except json.JSONDecodeError: 
logger.error("Failed to parse message as JSON.") - def _handle_text_response_in_conversation(self, response_event: str, response_json: dict): + def _handle_text_response_in_conversation( + self, + response_event: str, + response_json: dict, + ): # pylint: disable=too-many-branches payload = response_json["payload"] try: if response_event == RESPONSE_NAME_STARTED: @@ -581,7 +718,10 @@ def _handle_text_response_in_conversation(self, response_event: str, response_js self._handle_stopped() elif response_event == RESPONSE_NAME_STATE_CHANGED: self._handle_state_changed(payload["output"]["state"]) - logger.debug("service response change state: %s" % payload["output"]["state"]) + logger.debug( + "service response change state: %s", + payload["output"]["state"], + ) elif response_event == RESPONSE_NAME_REQUEST_ACCEPTED: self._handle_request_accepted() elif response_event == RESPONSE_NAME_SPEECH_STARTED: @@ -601,7 +741,7 @@ def _handle_text_response_in_conversation(self, response_event: str, response_js elif response_event == RESPONSE_NAME_HEART_BEAT: logger.debug("Server response heart beat") else: - logger.error("Unknown response name: {}", response_event) + logger.error("Unknown response name: %s", response_event) except json.JSONDecodeError: logger.error("Failed to parse message as JSON.") @@ -614,7 +754,7 @@ def _handle_request_accepted(self): def _handle_started(self, payload: dict): self.dialog_id = payload["dialog_id"] - self._callback.on_started(self.dialog_id) + self._callback.on_started(self.dialog_id) # type: ignore[arg-type] def _handle_stopped(self): self._callback.on_stopped() diff --git a/dashscope/multimodal/multimodal_request_params.py b/dashscope/multimodal/multimodal_request_params.py index 99c94d3..43eeb85 100644 --- a/dashscope/multimodal/multimodal_request_params.py +++ b/dashscope/multimodal/multimodal_request_params.py @@ -1,4 +1,5 @@ -from dataclasses import dataclass, field, asdict +# -*- coding: utf-8 -*- +from dataclasses import dataclass, field 
import uuid @@ -18,7 +19,7 @@ def to_dict(self): "action": self.action, "task_id": self.task_id, "request_id": self.task_id, - "streaming": self.streaming + "streaming": self.streaming, } @@ -50,9 +51,11 @@ def to_dict(self): } if self.parameters is not None: + # pylint: disable=assignment-from-no-return payload["parameters"] = self.parameters.to_dict() if self.input is not None: + # pylint: disable=assignment-from-no-return payload["input"] = self.input.to_dict() return payload @@ -70,11 +73,13 @@ def to_dict(self): "workspace_id": self.workspace_id, "app_id": self.app_id, "directive": self.directive, - "dialog_id": self.dialog_id + "dialog_id": self.dialog_id, } + + @dataclass class AsrPostProcessing: - replace_words: list = field(default=None) + replace_words: list = field(default=None) # type: ignore[arg-type] def to_dict(self): if self.replace_words is None: @@ -82,9 +87,10 @@ def to_dict(self): if len(self.replace_words) == 0: return None return { - "replace_words": [word.to_dict() for word in self.replace_words] + "replace_words": [word.to_dict() for word in self.replace_words], } + @dataclass class ReplaceWord: source: str = field(default=None) @@ -95,19 +101,23 @@ def to_dict(self): return { "source": self.source, "target": self.target, - "match_mode": self.match_mode + "match_mode": self.match_mode, } + @dataclass class Upstream: """struct for upstream""" + audio_format: str = field(default="pcm") # 上行语音格式,默认pcm.支持pcm/opus - type: str = field(default="AudioOnly") # 上行类型:AudioOnly 仅语音通话; AudioAndVideo 上传视频 + type: str = field( + default="AudioOnly", + ) # 上行类型:AudioOnly 仅语音通话; AudioAndVideo 上传视频 mode: str = field(default="tap2talk") # 客户端交互模式 push2talk/tap2talk/duplex sample_rate: int = field(default=16000) # 音频采样率 vocabulary_id: str = field(default=None) asr_post_processing: AsrPostProcessing = field(default=None) - pass_through_params: dict = field(default=None) + pass_through_params: dict = field(default=None) # type: ignore[arg-type] def to_dict(self): 
upstream: dict = { @@ -118,7 +128,9 @@ def to_dict(self): "vocabulary_id": self.vocabulary_id, } if self.asr_post_processing is not None: - upstream["asr_post_processing"] = self.asr_post_processing.to_dict() + upstream[ + "asr_post_processing" + ] = self.asr_post_processing.to_dict() if self.pass_through_params is not None: upstream.update(self.pass_through_params) @@ -134,12 +146,12 @@ class Downstream: sample_rate: int = field(default=0) # 语音音色 # 合成音频采样率 intermediate_text: str = field(default="transcript") # 控制返回给用户那些中间文本: debug: bool = field(default=False) # 控制是否返回debug信息 - # type_: str = field(default="Audio", metadata={"alias": "type"}) # 下行类型:Text:不需要下发语音;Audio:输出语音,默认值 + # type_: str = field(default="Audio", metadata={"alias": "type"}) # 下行类型:Text:不需要下发语音;Audio:输出语音,默认值 # noqa: E501 # pylint: disable=line-too-long audio_format: str = field(default="pcm") # 下行语音格式,默认pcm 。支持pcm/mp3 volume: int = field(default=50) # 语音音量 0-100 pitch_rate: int = field(default=100) # 语音语调 50-200 speech_rate: int = field(default=100) # 语音语速 50-200 - pass_through_params: dict = field(default=None) + pass_through_params: dict = field(default=None) # type: ignore[arg-type] def to_dict(self): stream: dict = { @@ -149,7 +161,7 @@ def to_dict(self): "audio_format": self.audio_format, "volume": self.volume, "pitch_rate": self.pitch_rate, - "speech_rate": self.speech_rate + "speech_rate": self.speech_rate, } if self.voice != "": stream["voice"] = self.voice @@ -170,7 +182,7 @@ def to_dict(self): return { "agent_id": self.agent_id, "prompt": self.prompt, - "vocabulary_id": self.vocabulary_id + "vocabulary_id": self.vocabulary_id, } @@ -184,7 +196,7 @@ def to_dict(self): return { "city_name": self.city_name, "latitude": self.latitude, - "longitude": self.longitude + "longitude": self.longitude, } @@ -194,7 +206,7 @@ class Network: def to_dict(self): return { - "ip": self.ip + "ip": self.ip, } @@ -204,7 +216,7 @@ class Device: def to_dict(self): return { - "uuid": self.uuid + "uuid": 
self.uuid, } @@ -218,7 +230,7 @@ class ClientInfo: def to_dict(self): info = { "user_id": self.user_id, - "sdk": "python" + "sdk": "python", } if self.device is not None: info["device"] = self.device.to_dict() @@ -231,13 +243,13 @@ def to_dict(self): @dataclass class BizParams: - user_defined_params: dict = field(default=None) - user_defined_tokens: dict = field(default=None) - tool_prompts: dict = field(default=None) - user_prompt_params: dict = field(default=None) - user_query_params: dict = field(default=None) - videos: list = field(default=None) - pass_through_params: dict = field(default=None) + user_defined_params: dict = field(default=None) # type: ignore[arg-type] + user_defined_tokens: dict = field(default=None) # type: ignore[arg-type] + tool_prompts: dict = field(default=None) # type: ignore[arg-type] + user_prompt_params: dict = field(default=None) # type: ignore[arg-type] + user_query_params: dict = field(default=None) # type: ignore[arg-type] + videos: list = field(default=None) # type: ignore[arg-type] + pass_through_params: dict = field(default=None) # type: ignore[arg-type] def to_dict(self): params = {} @@ -282,12 +294,11 @@ def to_dict(self): @dataclass class RequestToRespondParameters(DashPayloadParameters): - images: list = field(default=None) + images: list = field(default=None) # type: ignore[arg-type] biz_params: BizParams = field(default=None) def to_dict(self): - params = { - } + params = {} if self.images is not None: params["images"] = self.images if self.biz_params is not None: @@ -300,7 +311,7 @@ class RequestToRespondBodyInput(DashPayloadInput): app_id: str directive: str dialog_id: str - type_: str = field(metadata={"alias": "type"}, default= None) + type_: str = field(metadata={"alias": "type"}, default=None) text: str = field(default="") def to_dict(self): @@ -309,5 +320,5 @@ def to_dict(self): "directive": self.directive, "dialog_id": self.dialog_id, "type": self.type_, - "text": self.text - } \ No newline at end of file + "text": 
self.text, + } diff --git a/dashscope/multimodal/tingwu/__init__.py b/dashscope/multimodal/tingwu/__init__.py index e11883b..3cc6b26 100644 --- a/dashscope/multimodal/tingwu/__init__.py +++ b/dashscope/multimodal/tingwu/__init__.py @@ -1,10 +1,11 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. from .tingwu import TingWu from .tingwu_realtime import TingWuRealtime, TingWuRealtimeCallback __all__ = [ - 'TingWu', - 'TingWuRealtime', - 'TingWuRealtimeCallback' -] \ No newline at end of file + "TingWu", + "TingWuRealtime", + "TingWuRealtimeCallback", +] diff --git a/dashscope/multimodal/tingwu/tingwu.py b/dashscope/multimodal/tingwu/tingwu.py index 98101fa..15896f0 100644 --- a/dashscope/multimodal/tingwu/tingwu.py +++ b/dashscope/multimodal/tingwu/tingwu.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- from typing import Dict, Any from dashscope.api_entities.api_request_factory import _build_api_request @@ -7,22 +8,20 @@ class TingWu(BaseApi): - """API for TingWu APP. - - """ + """API for TingWu APP.""" task = None task_group = None function = None @classmethod - def call( - cls, - model: str, - user_defined_input: Dict[str, Any], - parameters: Dict[str, Any] = None, - api_key: str = None, - **kwargs + def call( # type: ignore[override] + cls, + model: str, + user_defined_input: Dict[str, Any], + parameters: Dict[str, Any] = None, + api_key: str = None, + **kwargs, ) -> DashScopeAPIResponse: """Call generation model service. @@ -45,29 +44,34 @@ def call( stream is True, return Generator, otherwise GenerationResponse. 
""" if model is None or not model: - raise ModelRequired('Model is required!') - input_config, parameters = cls._build_input_parameters(input_config=user_defined_input, - params=parameters, - **kwargs) + raise ModelRequired("Model is required!") + input_config, parameters = cls._build_input_parameters( + input_config=user_defined_input, + params=parameters, + **kwargs, + ) request = _build_api_request( model=model, input=input_config, api_key=api_key, - task_group=TingWu.task_group, - task=TingWu.task, - function=TingWu.function, + task_group=TingWu.task_group, # type: ignore[arg-type] + task=TingWu.task, # type: ignore[arg-type] + function=TingWu.function, # type: ignore[arg-type] is_service=False, - **parameters) + **parameters, + ) response = request.call() return response @classmethod - def _build_input_parameters(cls, - input_config, - params: Dict[str, Any] = None, - **kwargs): + def _build_input_parameters( + cls, + input_config, + params: Dict[str, Any] = None, + **kwargs, + ): parameters = {} if params is not None: parameters = params @@ -75,6 +79,6 @@ def _build_input_parameters(cls, input_param = input_config if kwargs.keys() is not None: - for key in kwargs.keys(): + for key in kwargs: # pylint: disable=consider-using-dict-items parameters[key] = kwargs[key] return input_param, {**parameters, **kwargs} diff --git a/dashscope/multimodal/tingwu/tingwu_realtime.py b/dashscope/multimodal/tingwu/tingwu_realtime.py index be3d24d..ea8775c 100644 --- a/dashscope/multimodal/tingwu/tingwu_realtime.py +++ b/dashscope/multimodal/tingwu/tingwu_realtime.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
import json @@ -9,16 +10,17 @@ from queue import Queue import dashscope from dashscope.client.base_api import BaseApi -from dashscope.common.error import (InvalidParameter, ModelRequired) -import websocket +from dashscope.common.error import InvalidParameter, ModelRequired +import websocket # pylint: disable=wrong-import-order +# pylint: disable=ungrouped-imports from dashscope.common.logging import logger from dashscope.protocol.websocket import ActionType class TingWuRealtimeCallback: """An interface that defines callback methods for getting TingWu results. - Derive from this class and implement its function to provide your own data. + Derive from this class and implement its function to provide your own data. """ def on_open(self) -> None: @@ -49,7 +51,6 @@ def on_close(self, close_status_code, close_msg): :param close_status_code :param close_msg """ - pass class TingWuRealtime(BaseApi): @@ -75,30 +76,32 @@ class TingWuRealtime(BaseApi): SILENCE_TIMEOUT_S = 60 - def __init__(self, - model: str, - callback: TingWuRealtimeCallback, - audio_format: str = "pcm", - sample_rate: int = 16000, - max_end_silence: int = None, - app_id: str = None, - terminology: str = None, - workspace: str = None, - api_key: str = None, - base_address: str = None, - data_id: str = None, - **kwargs): + def __init__( + self, + model: str, + callback: TingWuRealtimeCallback, + audio_format: str = "pcm", + sample_rate: int = 16000, + max_end_silence: int = None, + app_id: str = None, + terminology: str = None, + workspace: str = None, + api_key: str = None, + base_address: str = None, + data_id: str = None, + **kwargs, + ): if api_key is None: self.api_key = dashscope.api_key else: - self.api_key = api_key + self.api_key = api_key # type: ignore[has-type] if base_address is None: self.base_address = dashscope.base_websocket_api_url else: - self.base_address = base_address + self.base_address = base_address # type: ignore[has-type] if model is None: - raise ModelRequired('Model is 
required!') + raise ModelRequired("Model is required!") self.data_id = data_id self.max_end_silence = max_end_silence @@ -123,16 +126,23 @@ def __init__(self, self.request_id_confirmed = False self.last_request_id = uuid.uuid4().hex self.request = _Request() - self.response = _TingWuResponse(self._callback, self.close) # 传递 self.close 作为回调 - - def _on_message(self, ws, message): + self.response = _TingWuResponse( + self._callback, + self.close, + ) # 传递 self.close 作为回调 + + def _on_message( # pylint: disable=unused-argument + self, + ws, + message, + ): logger.debug(f"<<<<<<< Received message: {message}") if isinstance(message, str): self.response.handle_text_response(message) elif isinstance(message, (bytes, bytearray)): self.response.handle_binary_response(message) - def _on_error(self, ws, error): + def _on_error(self, ws, error): # pylint: disable=unused-argument logger.error(f"Error: {error}") if self._callback: error_code = "" # 默认错误码 @@ -142,12 +152,23 @@ def _on_error(self, ws, error): error_code = "1002" # 超时错误 elif "authentication" in str(error).lower(): error_code = "1003" # 认证错误 - self._callback.on_error(error_code=error_code, error_msg=str(error)) + self._callback.on_error( + error_code=error_code, + error_msg=str(error), + ) - def _on_close(self, ws, close_status_code, close_msg): + def _on_close( # pylint: disable=unused-argument + self, + ws, + close_status_code, + close_msg, + ): try: logger.debug( - "WebSocket connection closed with status {} and message {}".format(close_status_code, close_msg)) + "WebSocket connection closed with status %s and message %s", # noqa: E501 + close_status_code, + close_msg, + ) if close_status_code is None: close_status_code = 1000 if close_msg is None: @@ -156,7 +177,7 @@ def _on_close(self, ws, close_status_code, close_msg): except Exception as e: logger.error(f"Error: {e}") - def _on_open(self, ws): + def _on_open(self, ws): # pylint: disable=unused-argument self._callback.on_open() self._running = True @@ -167,10 
+188,12 @@ def start(self, **kwargs): """ interface for starting TingWu connection """ - assert self._callback is not None, 'Please set the callback to get the TingWu result.' # noqa E501 + assert ( + self._callback is not None + ), "Please set the callback to get the TingWu result." # noqa E501 if self._running: - raise InvalidParameter('TingWu client has started.') + raise InvalidParameter("TingWu client has started.") # self._start_stream_timestamp = -1 # self._first_package_timestamp = -1 @@ -195,7 +218,7 @@ def stop(self): _send_speech_json = self.request.generate_stop_request("stop") self._send_text_frame(_send_speech_json) - """inner class""" + """inner class""" # pylint: disable=pointless-string-statement def _send_start_request(self): """send start request""" @@ -209,7 +232,7 @@ def _send_start_request(self): terminology=self.terminology, max_end_silence=self.max_end_silence, data_id=self.data_id, - **self._kwargs + **self._kwargs, ) # send start request self._send_text_frame(_start_json) @@ -219,11 +242,14 @@ def _run_forever(self): def _connect(self, api_key: str): """init websocket connection""" - self.ws = websocket.WebSocketApp(self.base_address, header=self.request.get_websocket_header(api_key), - on_open=self._on_open, - on_message=self._on_message, - on_error=self._on_error, - on_close=self._on_close) + self.ws = websocket.WebSocketApp( + self.base_address, # type: ignore[has-type] + header=self.request.get_websocket_header(api_key), + on_open=self._on_open, + on_message=self._on_message, + on_error=self._on_error, + on_close=self._on_close, + ) self.thread = threading.Thread(target=self._run_forever) # 统一心跳机制配置 self.ws.ping_interval = 5 @@ -242,16 +268,19 @@ def _wait_for_connection(self): """wait for connection using event instead of busy waiting""" timeout = 5 start_time = time.time() - while not (self.ws.sock and self.ws.sock.connected) and (time.time() - start_time) < timeout: + while ( + not (self.ws.sock and self.ws.sock.connected) + and 
(time.time() - start_time) < timeout + ): time.sleep(0.1) # 短暂休眠,避免密集轮询 def _send_text_frame(self, text: str): # 避免在日志中记录敏感信息,如API密钥等 # 只记录非敏感信息 if '"Authorization"' not in text: - logger.info('>>>>>> send text frame : %s' % text) + logger.info(">>>>>> send text frame : %s", text) else: - logger.info('>>>>>> send text frame with authorization header') + logger.info(">>>>>> send text frame with authorization header") self.ws.send(text, websocket.ABNF.OPCODE_TEXT) def __send_binary_frame(self, binary: bytes): @@ -288,11 +317,11 @@ def send_audio_frame(self, buffer: bytes): InvalidParameter: Cannot send data to an uninitiated recognition. """ if self._running is False: - raise InvalidParameter('TingWu client has stopped.') + raise InvalidParameter("TingWu client has stopped.") if self._start_stream_timestamp < 0: self._start_stream_timestamp = time.time() * 1000 - logger.debug('send_audio_frame: {}'.format(len(buffer))) + logger.debug("send_audio_frame: %s", len(buffer)) self.__send_binary_frame(buffer) @@ -309,31 +338,32 @@ def __init__(self): self.workspace_id = None def get_websocket_header(self, api_key): - ua = 'dashscope/%s; python/%s; platform/%s; processor/%s' % ( - '1.18.0', # dashscope version - platform.python_version(), - platform.platform(), - platform.processor(), + ua = ( + f"dashscope/1.18.0; python/{platform.python_version()}; " + f"platform/{platform.platform()}; " + f"processor/{platform.processor()}" ) self.ws_headers = { "User-Agent": ua, "Authorization": f"bearer {api_key}", - "Accept": "application/json" + "Accept": "application/json", } - logger.info('websocket header: {}'.format(self.ws_headers)) + logger.info("websocket header: %s", self.ws_headers) return self.ws_headers - def generate_start_request(self, direction_name: str, - app_id: str, - model: str = None, - workspace_id: str = None, - audio_format: str = None, - sample_rate: int = None, - terminology: str = None, - max_end_silence: int = None, - data_id: str = None, - **kwargs - ) -> 
str: + def generate_start_request( + self, + direction_name: str, + app_id: str, + model: str = None, + workspace_id: str = None, + audio_format: str = None, + sample_rate: int = None, + terminology: str = None, + max_end_silence: int = None, + data_id: str = None, + **kwargs, + ) -> str: """ build start request. :param app_id: web console app id @@ -350,27 +380,36 @@ def generate_start_request(self, direction_name: str, : """ self._get_dash_request_header(ActionType.START) - parameters = self._get_start_parameters(audio_format=audio_format, sample_rate=sample_rate, - max_end_silence=max_end_silence, - terminology=terminology, - **kwargs) - self._get_dash_request_payload(direction_name=direction_name, app_id=app_id, workspace_id=workspace_id, - model=model, - data_id=data_id, - request_params=parameters) + parameters = self._get_start_parameters( + audio_format=audio_format, + sample_rate=sample_rate, + max_end_silence=max_end_silence, + terminology=terminology, + **kwargs, + ) + self._get_dash_request_payload( + direction_name=direction_name, + app_id=app_id, + workspace_id=workspace_id, + model=model, + data_id=data_id, + request_params=parameters, + ) cmd = { "header": self.header, - "payload": self.payload + "payload": self.payload, } return json.dumps(cmd) @staticmethod - def _get_start_parameters(audio_format: str = None, - sample_rate: int = None, - terminology: str = None, - max_end_silence: int = None, - **kwargs): + def _get_start_parameters( + audio_format: str = None, + sample_rate: int = None, + terminology: str = None, + max_end_silence: int = None, + **kwargs, + ): """ build start request parameters inner. 
:param kwargs: parameters @@ -378,13 +417,13 @@ def _get_start_parameters(audio_format: str = None, """ parameters = {} if audio_format is not None: - parameters['format'] = audio_format + parameters["format"] = audio_format if sample_rate is not None: - parameters['sampleRate'] = sample_rate + parameters["sampleRate"] = sample_rate if terminology is not None: - parameters['terminology'] = terminology + parameters["terminology"] = terminology if max_end_silence is not None: - parameters['maxEndSilence'] = max_end_silence + parameters["maxEndSilence"] = max_end_silence if kwargs is not None and len(kwargs) != 0: parameters.update(kwargs) return parameters @@ -400,7 +439,7 @@ def generate_stop_request(self, direction_name: str) -> str: cmd = { "header": self.header, - "payload": self.payload + "payload": self.payload, } return json.dumps(cmd) @@ -412,14 +451,16 @@ def _get_dash_request_header(self, action: str): self.task_id = get_random_uuid() self.header = DashHeader(action=action, task_id=self.task_id).to_dict() - def _get_dash_request_payload(self, direction_name: str, - app_id: str, - workspace_id: str = None, - custom_input=None, - model: str = None, - data_id: str = None, - request_params=None, - ): + def _get_dash_request_payload( + self, + direction_name: str, + app_id: str, + workspace_id: str = None, + custom_input=None, + model: str = None, + data_id: str = None, + request_params=None, + ): """ build start request payload inner. 
:param direction_name: inner direction name @@ -430,19 +471,19 @@ def _get_dash_request_payload(self, direction_name: str, :param model: model name """ if custom_input is not None: - input = custom_input + input = custom_input # pylint: disable=redefined-builtin else: input = RequestBodyInput( workspace_id=workspace_id, app_id=app_id, directive=direction_name, - data_id=data_id + data_id=data_id, ) self.payload = DashPayload( model=model, input=input.to_dict(), - parameters=request_params + parameters=request_params, ).to_dict() @@ -458,27 +499,35 @@ def handle_text_response(self, response_json: str): handle text response. :param response_json: json format response from server """ - logger.info("<<<<<< server response: %s" % response_json) + logger.info("<<<<<< server response: %s", response_json) try: # try to parse response as json json_data = json.loads(response_json) - header = json_data.get('header', {}) - if header.get('event') == 'task-failed': - logger.error('Server returned invalid message: %s' % response_json) + header = json_data.get("header", {}) + if header.get("event") == "task-failed": + logger.error( + "Server returned invalid message: %s", + response_json, + ) if self._callback: - self._callback.on_error(error_code=header.get('error_code'), - error_msg=header.get('error_message')) + self._callback.on_error( + error_code=header.get("error_code"), + error_msg=header.get("error_message"), + ) return - if header.get('event') == "task-started": - self._handle_started(header.get('task_id')) + if header.get("event") == "task-started": + self._handle_started(header.get("task_id")) return - payload = json_data.get('payload', {}) - output = payload.get('output', {}) + payload = json_data.get("payload", {}) + output = payload.get("output", {}) if output is not None: - action = output.get('action') - logger.info("Server response action: %s" % action) - self._handle_tingwu_agent_text_response(action=action, response_json=json_data) + action = 
output.get("action") + logger.info("Server response action: %s", action) + self._handle_tingwu_agent_text_response( + action=action, + response_json=json_data, + ) except json.JSONDecodeError: logger.error("Failed to parse message as JSON.") @@ -488,14 +537,23 @@ def handle_binary_response(self, response_binary: bytes): handle binary response. :param response_binary: server response binary。 """ - logger.info("<<<<<< server response binary length: %d" % len(response_binary)) + logger.info( + "<<<<<< server response binary length: %d", + len(response_binary), + ) - def _handle_tingwu_agent_text_response(self, action: str, response_json: dict): - payload = response_json.get('payload', {}) - output = payload.get('output', {}) + def _handle_tingwu_agent_text_response( + self, + action: str, + response_json: dict, + ): + payload = response_json.get("payload", {}) + output = payload.get("output", {}) if action == "task-failed": - self._callback.on_error(error_code=output.get('errorCode'), - error_msg=output.get('errorMessage')) + self._callback.on_error( + error_code=output.get("errorCode"), + error_msg=output.get("errorMessage"), + ) elif action == "speech-listen": self._callback.on_speech_listen(response_json) elif action == "recognize-result": @@ -507,11 +565,11 @@ def _handle_tingwu_agent_text_response(self, action: str, response_json: dict): if self._close_callback is not None: self._close_callback() else: - logger.info("Unknown response name:" + action) + logger.info("Unknown response name: %s", action) def _handle_started(self, task_id: str): self.task_id = task_id - self._callback.on_started(self.task_id) + self._callback.on_started(self.task_id) # type: ignore[arg-type] def get_random_uuid() -> str: @@ -520,7 +578,7 @@ def get_random_uuid() -> str: @dataclass -class RequestBodyInput(): +class RequestBodyInput: app_id: str directive: str data_id: str = field(default=None) @@ -549,7 +607,7 @@ def to_dict(self): "action": self.action, "task_id": self.task_id, 
"request_id": self.task_id, - "streaming": self.streaming + "streaming": self.streaming, } @@ -559,8 +617,8 @@ class DashPayload: function: str = field(default="generation") model: str = field(default="") task: str = field(default="multimodal-generation") - parameters: dict = field(default=None) - input: dict = field(default=None) + parameters: dict = field(default=None) # type: ignore[arg-type] + input: dict = field(default=None) # type: ignore[arg-type] def to_dict(self): payload = { diff --git a/dashscope/nlp/understanding.py b/dashscope/nlp/understanding.py index fafb998..137f0a4 100644 --- a/dashscope/nlp/understanding.py +++ b/dashscope/nlp/understanding.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. from dashscope.api_entities.dashscope_response import DashScopeAPIResponse @@ -8,27 +9,30 @@ class Understanding(BaseApi): - nlu_task = 'nlu' + nlu_task = "nlu" """API for AI-Generated Content(AIGC) models. """ + class Models: - opennlu_v1 = 'opennlu-v1' + opennlu_v1 = "opennlu-v1" @classmethod - def call(cls, - model: str, - sentence: str = None, - labels: str = None, - task: str = None, - api_key: str = None, - **kwargs) -> DashScopeAPIResponse: + def call( # type: ignore[override] + cls, + model: str, + sentence: str = None, + labels: str = None, + task: str = None, + api_key: str = None, + **kwargs, + ) -> DashScopeAPIResponse: """Call generation model service. Args: model (str): The requested model, such as opennlu-v1 - sentence (str): The text content entered by the user that needs to be processed supports both Chinese and English. (The maximum limit for input is 1024 tokens, which is the sum of all input fields). # noqa E501 - labels (list): For the extraction task, label is the name of the type that needs to be extracted. For classification tasks, label is the classification system. Separate different labels with Chinese commas.. 
# noqa E501 + sentence (str): The text content entered by the user that needs to be processed supports both Chinese and English. (The maximum limit for input is 1024 tokens, which is the sum of all input fields). # noqa E501 # pylint: disable=line-too-long + labels (list): For the extraction task, label is the name of the type that needs to be extracted. For classification tasks, label is the classification system. Separate different labels with Chinese commas.. # noqa E501 # pylint: disable=line-too-long task (str): Task type, optional as extraction or classification, default as extraction. api_key (str, optional): The api api_key, can be None, if None, will get by default rule(TODO: api key doc). @@ -36,29 +40,50 @@ def call(cls, Returns: DashScopeAPIResponse: The understanding result. """ - if (sentence is None or not sentence) or (labels is None - or not labels): - raise InputRequired('sentence and labels is required!') + if (sentence is None or not sentence) or ( + labels is None or not labels + ): + raise InputRequired("sentence and labels is required!") if model is None or not model: - raise ModelRequired('Model is required!') - if kwargs.pop('stream', False): # not support stream - logger.warning('stream option not supported for Understanding.') + raise ModelRequired("Model is required!") + if kwargs.pop("stream", False): # not support stream + logger.warning("stream option not supported for Understanding.") task_group, function = _get_task_group_and_task(__name__) - input, parameters = cls._build_input_parameters( - model, sentence, labels, task, **kwargs) - return super().call(model=model, - task_group=task_group, - task=Understanding.nlu_task, - function=function, - api_key=api_key, - input=input, - **parameters) + ( + input, # pylint: disable=redefined-builtin + parameters, + ) = cls._build_input_parameters( + model, + sentence, + labels, + task, + **kwargs, + ) + return super().call( + model=model, + task_group=task_group, + 
task=Understanding.nlu_task, + function=function, + api_key=api_key, + input=input, + **parameters, + ) @classmethod - def _build_input_parameters(cls, model, sentence, labels, task, **kwargs): + def _build_input_parameters( # pylint: disable=unused-argument + cls, + model, + sentence, + labels, + task, + **kwargs, + ): parameters = {} - input = {'sentence': sentence, 'labels': labels} + input = { # pylint: disable=redefined-builtin + "sentence": sentence, + "labels": labels, + } if task is not None and task: - input['task'] = task + input["task"] = task return input, {**parameters, **kwargs} diff --git a/dashscope/protocol/websocket.py b/dashscope/protocol/websocket.py index bf2f502..e650e86 100644 --- a/dashscope/protocol/websocket.py +++ b/dashscope/protocol/websocket.py @@ -1,30 +1,31 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. class WebsocketStreamingMode: # TODO how to know request is duplex or other. - NONE = 'none' # no stream - IN = 'in' # stream in - OUT = 'out' - DUPLEX = 'duplex' + NONE = "none" # no stream + IN = "in" # stream in + OUT = "out" + DUPLEX = "duplex" -ACTION_KEY = 'action' -EVENT_KEY = 'event' -HEADER = 'header' -TASK_ID = 'task_id' -ERROR_NAME = 'error_code' -ERROR_MESSAGE = 'error_message' +ACTION_KEY = "action" +EVENT_KEY = "event" +HEADER = "header" +TASK_ID = "task_id" +ERROR_NAME = "error_code" +ERROR_MESSAGE = "error_message" class EventType: - STARTED = 'task-started' - GENERATED = 'result-generated' - FINISHED = 'task-finished' - FAILED = 'task-failed' + STARTED = "task-started" + GENERATED = "result-generated" + FINISHED = "task-finished" + FAILED = "task-failed" class ActionType: - START = 'run-task' - CONTINUE = 'continue-task' - FINISHED = 'finish-task' + START = "run-task" + CONTINUE = "continue-task" + FINISHED = "finish-task" diff --git a/dashscope/rerank/text_rerank.py b/dashscope/rerank/text_rerank.py index 9f4cd39..c38bf1d 100644 --- a/dashscope/rerank/text_rerank.py +++ 
b/dashscope/rerank/text_rerank.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. from typing import List @@ -9,22 +10,25 @@ class TextReRank(BaseApi): - task = 'text-rerank' + task = "text-rerank" """API for rerank models. """ + class Models: - gte_rerank = 'gte-rerank' + gte_rerank = "gte-rerank" @classmethod - def call(cls, - model: str, - query: str, - documents: List[str], - return_documents: bool = None, - top_n: int = None, - api_key: str = None, - **kwargs) -> ReRankResponse: + def call( # type: ignore[override] + cls, + model: str, + query: str, + documents: List[str], + return_documents: bool = None, + top_n: int = None, + api_key: str = None, + **kwargs, + ) -> ReRankResponse: """Calling rerank service. Args: @@ -33,7 +37,7 @@ def call(cls, documents (List[str]): The documents to rank. return_documents(bool, `optional`): enable return origin documents, system default is false. - top_n(int, `optional`): how many documents to return, default return + top_n(int, `optional`): how many documents to return, default return # noqa: E501 all the documents. api_key (str, optional): The DashScope api key. Defaults to None. 
@@ -46,24 +50,29 @@ def call(cls, """ if query is None or documents is None or not documents: - raise InputRequired('query and documents are required!') + raise InputRequired("query and documents are required!") if model is None or not model: - raise ModelRequired('Model is required!') + raise ModelRequired("Model is required!") task_group, function = _get_task_group_and_task(__name__) - input = {'query': query, 'documents': documents} + input = { # pylint: disable=redefined-builtin + "query": query, + "documents": documents, + } parameters = {} if return_documents is not None: - parameters['return_documents'] = return_documents + parameters["return_documents"] = return_documents if top_n is not None: - parameters['top_n'] = top_n + parameters["top_n"] = top_n parameters = {**parameters, **kwargs} - response = super().call(model=model, - task_group=task_group, - task=TextReRank.task, - function=function, - api_key=api_key, - input=input, - **parameters) + response = super().call( + model=model, + task_group=task_group, + task=TextReRank.task, + function=function, + api_key=api_key, + input=input, + **parameters, # type: ignore[arg-type] + ) return ReRankResponse.from_api_response(response) diff --git a/dashscope/threads/__init__.py b/dashscope/threads/__init__.py index ae19c26..5ae531f 100644 --- a/dashscope/threads/__init__.py +++ b/dashscope/threads/__init__.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
# yapf: disable @@ -5,22 +6,24 @@ from dashscope.threads.messages.messages import Messages from dashscope.threads.runs.runs import Runs from dashscope.threads.runs.steps import Steps -from dashscope.threads.thread_types import (MessageFile, Run, RunList, RunStep, - RunStepList, Thread, ThreadMessage, - ThreadMessageList) +from dashscope.threads.thread_types import ( + MessageFile, Run, RunList, RunStep, + RunStepList, Thread, ThreadMessage, + ThreadMessageList, +) from dashscope.threads.threads import Threads __all__ = [ - MessageFile, - Messages, - Run, - Runs, - RunList, - Steps, - RunStep, - RunStepList, - Threads, - Thread, - ThreadMessage, - ThreadMessageList, + 'MessageFile', + 'Messages', + 'Run', + 'Runs', + 'RunList', + 'Steps', + 'RunStep', + 'RunStepList', + 'Threads', + 'Thread', + 'ThreadMessage', + 'ThreadMessageList', ] diff --git a/dashscope/threads/messages/files.py b/dashscope/threads/messages/files.py index ba6aef9..81e5c00 100644 --- a/dashscope/threads/messages/files.py +++ b/dashscope/threads/messages/files.py @@ -1,24 +1,27 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. from dashscope.client.base_api import GetStatusMixin, ListObjectMixin from dashscope.common.error import InputRequired from dashscope.threads.thread_types import MessageFile, MessageFileList -__all__ = ['Files'] +__all__ = ["Files"] class Files(ListObjectMixin, GetStatusMixin): - SUB_PATH = 'messages' # useless + SUB_PATH = "messages" # useless @classmethod - def retrieve(cls, - file_id: str, - *, - thread_id: str, - message_id: str, - workspace: str = None, - api_key: str = None, - **kwargs) -> MessageFile: + def retrieve( + cls, + file_id: str, + *, + thread_id: str, + message_id: str, + workspace: str = None, + api_key: str = None, + **kwargs, + ) -> MessageFile: """Retrieve the `MessageFile`. Args: @@ -31,22 +34,26 @@ def retrieve(cls, Returns: MessageFile: The `MessageFile` object. 
""" - return cls.get(file_id, - thread_id=thread_id, - message_id=message_id, - workspace=workspace, - api_key=api_key, - **kwargs) + return cls.get( + file_id, + thread_id=thread_id, + message_id=message_id, + workspace=workspace, + api_key=api_key, + **kwargs, + ) @classmethod - def get(cls, - file_id: str, - *, - message_id: str, - thread_id: str, - workspace: str = None, - api_key: str = None, - **kwargs) -> MessageFile: + def get( # type: ignore[override] + cls, + file_id: str, + *, + message_id: str, + thread_id: str, + workspace: str = None, + api_key: str = None, + **kwargs, + ) -> MessageFile: """Retrieve the `MessageFile`. Args: @@ -61,53 +68,60 @@ def get(cls, """ if not thread_id or not message_id or not file_id: raise InputRequired( - 'thread id, message id and file id are required!') + "thread id, message id and file id are required!", + ) response = super().get( message_id, - path=f'threads/{thread_id}/messages/{message_id}/files/{file_id}', + path=f"threads/{thread_id}/messages/{message_id}/files/{file_id}", workspace=workspace, api_key=api_key, flattened_output=True, - **kwargs) + **kwargs, + ) return MessageFile(**response) @classmethod - def list(cls, - message_id: str, - *, - thread_id: str, - limit: int = None, - order: str = None, - after: str = None, - before: str = None, - workspace: str = None, - api_key: str = None, - **kwargs) -> MessageFileList: + def list( # type: ignore[override] + cls, + message_id: str, + *, + thread_id: str, + limit: int = None, + order: str = None, + after: str = None, + before: str = None, + workspace: str = None, + api_key: str = None, + **kwargs, + ) -> MessageFileList: """List message files. Args: thread_id (str): The thread id. message_id (str): The message_id. - limit (int, optional): How many assistant to retrieve. Defaults to None. + limit (int, optional): + How many assistant to retrieve. Defaults to None. order (str, optional): Sort order by created_at. Defaults to None. 
after (str, optional): Assistant id after. Defaults to None. before (str, optional): Assistant id before. Defaults to None. - workspace (str, optional): The DashScope workspace id. Defaults to None. + workspace (str, optional): + The DashScope workspace id. Defaults to None. api_key (str, optional): Your DashScope api key. Defaults to None. Returns: MessageFileList: The `MessageFileList`. """ if not thread_id or not message_id: - raise InputRequired('thread id, message id are required!') + raise InputRequired("thread id, message id are required!") response = super().list( limit=limit, order=order, after=after, before=before, - path=f'threads/{thread_id}/messages/{message_id}/files', + path=f"threads/{thread_id}/messages/{message_id}/files", workspace=workspace, api_key=api_key, flattened_output=True, - **kwargs) + **kwargs, + ) return MessageFileList(**response) diff --git a/dashscope/threads/messages/messages.py b/dashscope/threads/messages/messages.py index a673ac0..2420609 100644 --- a/dashscope/threads/messages/messages.py +++ b/dashscope/threads/messages/messages.py @@ -1,102 +1,125 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
from typing import Dict, List, Optional -from dashscope.client.base_api import (CreateMixin, GetStatusMixin, - ListObjectMixin, UpdateMixin) +from dashscope.client.base_api import ( + CreateMixin, + GetStatusMixin, + ListObjectMixin, + UpdateMixin, +) from dashscope.common.error import InputRequired from dashscope.threads.thread_types import ThreadMessage, ThreadMessageList -__all__ = ['Messages'] +__all__ = ["Messages"] class Messages(CreateMixin, ListObjectMixin, GetStatusMixin, UpdateMixin): - SUB_PATH = 'messages' # useless + SUB_PATH = "messages" # useless @classmethod - def call(cls, - thread_id: str, - *, - content: str, - role: str = 'user', - file_ids: List[str] = [], - metadata: Optional[object] = None, - workspace: str = None, - api_key: str = None, - **kwargs) -> ThreadMessage: + # pylint: disable=dangerous-default-value + def call( # type: ignore[override] + cls, + thread_id: str, + *, + content: str, + role: str = "user", + file_ids: List[str] = [], + metadata: Optional[object] = None, + workspace: str = None, + api_key: str = None, + **kwargs, + ) -> ThreadMessage: """Create message of thread. Args: thread_id (str): The thread id. content (str): The message content. role (str, optional): The message role. Defaults to 'user'. - file_ids (List[str], optional): The file_ids include in message. Defaults to []. - metadata (Optional[object], optional): The custom key/value pairs. Defaults to None. - workspace (str, optional): The DashScope workspace id. Defaults to None. + file_ids (List[str], optional): + The file_ids include in message. Defaults to []. + metadata (Optional[object], optional): + The custom key/value pairs. Defaults to None. + workspace (str, optional): + The DashScope workspace id. Defaults to None. api_key (str, optional): The DashScope api key. Defaults to None. Returns: ThreadMessage: The `ThreadMessage` object. 
""" - return cls.create(thread_id, - content=content, - role=role, - file_ids=file_ids, - metadata=metadata, - workspace=workspace, - **kwargs) + return cls.create( + thread_id, + content=content, + role=role, + file_ids=file_ids, + metadata=metadata, + workspace=workspace, + **kwargs, + ) @classmethod - def create(cls, - thread_id: str, - *, - content: str, - role: str = 'user', - file_ids: List[str] = [], - metadata: Optional[object] = None, - workspace: str = None, - api_key: str = None, - **kwargs) -> ThreadMessage: + # pylint: disable=dangerous-default-value + def create( + cls, + thread_id: str, + *, + content: str, + role: str = "user", + file_ids: List[str] = [], + metadata: Optional[object] = None, + workspace: str = None, + api_key: str = None, + **kwargs, + ) -> ThreadMessage: """Create message of thread. Args: thread_id (str): The thread id. content (str): The message content. role (str, optional): The message role. Defaults to 'user'. - file_ids (List[str], optional): The file_ids include in message. Defaults to []. - metadata (Optional[object], optional): The custom key/value pairs. Defaults to None. - workspace (str, optional): The DashScope workspace id. Defaults to None. + file_ids (List[str], optional): + The file_ids include in message. Defaults to []. + metadata (Optional[object], optional): + The custom key/value pairs. Defaults to None. + workspace (str, optional): + The DashScope workspace id. Defaults to None. api_key (str, optional): The DashScope api key. Defaults to None. Returns: ThreadMessage: The `ThreadMessage` object. 
""" - cls.SUB_PATH = '%s/messages' % thread_id + cls.SUB_PATH = f"{thread_id}/messages" data = {} if not thread_id or not content: - raise InputRequired('thread_id and content are required!') - data['content'] = content - data['role'] = role + raise InputRequired("thread_id and content are required!") + data["content"] = content + data["role"] = role if metadata: - data['metadata'] = metadata + data["metadata"] = metadata if file_ids: - data['file_ids'] = file_ids - response = super().call(data=data, - path=f'threads/{thread_id}/messages', - api_key=api_key, - flattened_output=True, - workspace=workspace, - **kwargs) + data["file_ids"] = file_ids + response = super().call( + data=data, + path=f"threads/{thread_id}/messages", + api_key=api_key, + flattened_output=True, + workspace=workspace, + **kwargs, + ) return ThreadMessage(**response) @classmethod - def retrieve(cls, - message_id: str, - *, - thread_id: str, - workspace: str = None, - api_key: str = None, - **kwargs) -> ThreadMessage: + def retrieve( + cls, + message_id: str, + *, + thread_id: str, + workspace: str = None, + api_key: str = None, + **kwargs, + ) -> ThreadMessage: """Get the `ThreadMessage`. Args: @@ -108,20 +131,24 @@ def retrieve(cls, Returns: ThreadMessage: The `ThreadMessage` object. """ - return cls.get(message_id, - thread_id=thread_id, - workspace=workspace, - api_key=api_key, - **kwargs) + return cls.get( + message_id, + thread_id=thread_id, + workspace=workspace, + api_key=api_key, + **kwargs, + ) @classmethod - def get(cls, - message_id: str, - *, - thread_id: str, - workspace: str = None, - api_key: str = None, - **kwargs) -> ThreadMessage: + def get( # type: ignore[override] + cls, + message_id: str, + *, + thread_id: str, + workspace: str = None, + api_key: str = None, + **kwargs, + ) -> ThreadMessage: """Get the `ThreadMessage`. Args: @@ -134,63 +161,72 @@ def get(cls, ThreadMessage: The `ThreadMessage` object. 
""" if not message_id or not thread_id: - raise InputRequired('thread id, message id are required!') + raise InputRequired("thread id, message id are required!") response = super().get( message_id, - path=f'threads/{thread_id}/messages/{message_id}', + path=f"threads/{thread_id}/messages/{message_id}", workspace=workspace, api_key=api_key, flattened_output=True, - **kwargs) + **kwargs, + ) return ThreadMessage(**response) @classmethod - def list(cls, - thread_id: str, - *, - limit: int = None, - order: str = None, - after: str = None, - before: str = None, - workspace: str = None, - api_key: str = None, - **kwargs) -> ThreadMessageList: + def list( # type: ignore[override] + cls, + thread_id: str, + *, + limit: int = None, + order: str = None, + after: str = None, + before: str = None, + workspace: str = None, + api_key: str = None, + **kwargs, + ) -> ThreadMessageList: """List message of the thread. Args: thread_id (str): The thread id. - limit (int, optional): How many assistant to retrieve. Defaults to None. + limit (int, optional): + How many assistant to retrieve. Defaults to None. order (str, optional): Sort order by created_at. Defaults to None. after (str, optional): Assistant id after. Defaults to None. before (str, optional): Assistant id before. Defaults to None. - workspace (str, optional): The DashScope workspace id. Defaults to None. + workspace (str, optional): + The DashScope workspace id. Defaults to None. api_key (str, optional): Your DashScope api key. Defaults to None. Returns: ThreadMessageList: The `ThreadMessageList` object. 
""" if not thread_id: - raise InputRequired('thread id is required!') - response = super().list(limit=limit, - order=order, - after=after, - before=before, - path=f'threads/{thread_id}/messages', - workspace=workspace, - api_key=api_key, - flattened_output=True, - **kwargs) + raise InputRequired("thread id is required!") + response = super().list( + limit=limit, + order=order, + after=after, + before=before, + path=f"threads/{thread_id}/messages", + workspace=workspace, + api_key=api_key, + flattened_output=True, + **kwargs, + ) return ThreadMessageList(**response) @classmethod - def update(cls, - message_id: str, - *, - thread_id: str, - metadata: Dict = None, - workspace: str = None, - api_key: str = None, - **kwargs) -> ThreadMessage: + def update( # type: ignore[override] + cls, + message_id: str, + *, + thread_id: str, + metadata: Dict = None, + workspace: str = None, + api_key: str = None, + **kwargs, + ) -> ThreadMessage: """Update an message of the thread. Args: @@ -198,23 +234,27 @@ def update(cls, message_id (str): The message id. content (str): The message content. role (str, optional): The message role. Defaults to 'user'. - file_ids (List[str], optional): The file_ids include in message. Defaults to []. - metadata (Optional[object], optional): The custom key/value pairs. Defaults to None. - workspace (str, optional): The DashScope workspace id. Defaults to None. + file_ids (List[str], optional): + The file_ids include in message. Defaults to []. + metadata (Optional[object], optional): + The custom key/value pairs. Defaults to None. + workspace (str, optional): + The DashScope workspace id. Defaults to None. api_key (str, optional): The DashScope api key. Defaults to None. Returns: ThreadMessage: The `ThreadMessage` object. 
""" if not thread_id or not message_id: - raise InputRequired('thread id and message id are required!') - response = super().update(target=message_id, - json={'metadata': metadata}, - path='threads/%s/messages/%s' % - (thread_id, message_id), - api_key=api_key, - workspace=workspace, - flattened_output=True, - method='post', - **kwargs) + raise InputRequired("thread id and message id are required!") + response = super().update( + target=message_id, + json={"metadata": metadata}, + path=f"threads/{thread_id}/messages/{message_id}", + api_key=api_key, + workspace=workspace, + flattened_output=True, + method="post", + **kwargs, + ) return ThreadMessage(**response) diff --git a/dashscope/threads/runs/runs.py b/dashscope/threads/runs/runs.py index cba04e3..ed5f30f 100644 --- a/dashscope/threads/runs/runs.py +++ b/dashscope/threads/runs/runs.py @@ -1,107 +1,138 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. import time from http import HTTPStatus from typing import Dict, List, Optional -from dashscope.client.base_api import (CancelMixin, CreateMixin, - GetStatusMixin, ListObjectMixin, - UpdateMixin) -from dashscope.common.error import (AssistantError, InputRequired, - TimeoutException) +from dashscope.client.base_api import ( + CancelMixin, + CreateMixin, + GetStatusMixin, + ListObjectMixin, + UpdateMixin, +) +from dashscope.common.error import ( + AssistantError, + InputRequired, + TimeoutException, +) from dashscope.common.logging import logger -from dashscope.threads.thread_types import (Run, RunList, RunStep, - RunStepDelta, Thread, - ThreadMessage, ThreadMessageDelta) - -__all__ = ['Runs'] - - -class Runs(CreateMixin, CancelMixin, ListObjectMixin, GetStatusMixin, - UpdateMixin): - SUB_PATH = 'RUNS' # useless +from dashscope.threads.thread_types import ( + Run, + RunList, + RunStep, + RunStepDelta, + Thread, + ThreadMessage, + ThreadMessageDelta, +) + +__all__ = ["Runs"] + + +class Runs( + CreateMixin, + CancelMixin, + ListObjectMixin, 
+ GetStatusMixin, + UpdateMixin, +): + SUB_PATH = "RUNS" # useless @classmethod - def create_thread_and_run(cls, - *, - assistant_id: str, - thread: Optional[Dict] = None, - model: Optional[str] = None, - instructions: Optional[str] = None, - additional_instructions: Optional[str] = None, - tools: Optional[List[Dict]] = None, - stream: Optional[bool] = False, - metadata: Optional[Dict] = None, - workspace: str = None, - extra_body: Optional[Dict] = None, - api_key: str = None, - **kwargs) -> Run: + def create_thread_and_run( + cls, + *, + assistant_id: str, + thread: Optional[Dict] = None, + model: Optional[str] = None, + instructions: Optional[str] = None, + additional_instructions: Optional[str] = None, + tools: Optional[List[Dict]] = None, + stream: Optional[bool] = False, + metadata: Optional[Dict] = None, + workspace: str = None, + extra_body: Optional[Dict] = None, + api_key: str = None, + **kwargs, + ) -> Run: if not assistant_id: - raise InputRequired('assistant_id is required') - data = {'assistant_id': assistant_id} + raise InputRequired("assistant_id is required") + data = {"assistant_id": assistant_id} if thread: - data['thread'] = thread + data["thread"] = thread if model: - data['model'] = model + data["model"] = model if instructions: - data['instructions'] = instructions + data["instructions"] = instructions if additional_instructions: - data['additional_instructions'] = additional_instructions + data["additional_instructions"] = additional_instructions if tools: - data['tools'] = tools + data["tools"] = tools if metadata: - data['metadata'] = metadata - data['stream'] = stream + data["metadata"] = metadata + data["stream"] = stream if extra_body is not None and extra_body: data = {**data, **extra_body} - response = super().call(data=data, - path='threads/runs', - api_key=api_key, - flattened_output=True, - stream=stream, - workspace=workspace, - **kwargs) + response = super().call( + data=data, + path="threads/runs", + api_key=api_key, + 
flattened_output=True, + stream=stream, # type: ignore[arg-type] + workspace=workspace, + **kwargs, + ) if stream: - return ((event_type, cls.convert_stream_object(event_type, item)) - for event_type, item in response) + return ( # type: ignore[return-value] + (event_type, cls.convert_stream_object(event_type, item)) + for event_type, item in response + ) else: return Run(**response) @classmethod - def create(cls, - thread_id: str, - *, - assistant_id: str, - model: Optional[str] = None, - instructions: Optional[str] = None, - additional_instructions: Optional[str] = None, - tools: Optional[List[Dict]] = None, - metadata: Optional[Dict] = None, - stream: Optional[bool] = False, - workspace: str = None, - extra_body: Optional[Dict] = None, - api_key: str = None, - top_p: Optional[float] = None, - top_k: Optional[int] = None, - temperature: Optional[float] = None, - max_tokens: Optional[int] = None, - **kwargs) -> Run: + def create( # pylint: disable=too-many-branches + cls, + thread_id: str, + *, + assistant_id: str, + model: Optional[str] = None, + instructions: Optional[str] = None, + additional_instructions: Optional[str] = None, + tools: Optional[List[Dict]] = None, + metadata: Optional[Dict] = None, + stream: Optional[bool] = False, + workspace: str = None, + extra_body: Optional[Dict] = None, + api_key: str = None, + top_p: Optional[float] = None, + top_k: Optional[int] = None, + temperature: Optional[float] = None, + max_tokens: Optional[int] = None, + **kwargs, + ) -> Run: """Create a run. Args: thread_id (str): The thread to run. assistant_id (str): The assistant id to run. model (str): The model to use. - instructions (str, optional): The system instructions this assistant uses. Defaults to None. - additional_instructions (Optional[str], optional): Appends additional + instructions (str, optional): + The system instructions this assistant uses. Defaults to None. 
+ additional_instructions (Optional[str], optional): + Appends additional instructions at the end of the instructions for the run. This is useful for modifying the behavior on a per-run basis without overriding other instructions.. Defaults to None. - tools (Optional[str], optional): List of tools to use.. Defaults to []. - metadata (Dict, optional): Custom key-value pairs associate with run. Defaults to None. + tools (Optional[str], optional): List of tools to use.. Defaults to []. # noqa: E501 + metadata (Dict, optional): + Custom key-value pairs associate with run. Defaults to None. workspace (str): The DashScope workspace id. - api_key (str, optional): The DashScope workspace id. Defaults to None. + api_key (str, optional): + The DashScope workspace id. Defaults to None. Raises: InputRequired: thread and assistant is required. @@ -110,109 +141,119 @@ def create(cls, Run: The `Run` object. """ if not thread_id or not assistant_id: - raise InputRequired('thread_id and assistant_id is required') - data = {'assistant_id': assistant_id} + raise InputRequired("thread_id and assistant_id is required") + data = {"assistant_id": assistant_id} if model: - data['model'] = model + data["model"] = model if instructions: - data['instructions'] = instructions + data["instructions"] = instructions if additional_instructions: - data['additional_instructions'] = additional_instructions + data["additional_instructions"] = additional_instructions if tools: - data['tools'] = tools + data["tools"] = tools if metadata: - data['metadata'] = metadata - data['stream'] = stream + data["metadata"] = metadata + data["stream"] = stream if extra_body is not None and extra_body: data = {**data, **extra_body} if top_p is not None: - data['top_p'] = top_p + data["top_p"] = top_p if top_k is not None: - data['top_k'] = top_k + data["top_k"] = top_k if temperature is not None: - data['temperature'] = temperature + data["temperature"] = temperature if max_tokens is not None: - data['max_tokens'] = 
max_tokens - - response = super().call(data=data, - path=f'threads/{thread_id}/runs', - api_key=api_key, - flattened_output=True, - stream=stream, - workspace=workspace, - **kwargs) + data["max_tokens"] = max_tokens + + response = super().call( + data=data, + path=f"threads/{thread_id}/runs", + api_key=api_key, + flattened_output=True, + stream=stream, # type: ignore[arg-type] + workspace=workspace, + **kwargs, + ) if stream: - return ((event_type, cls.convert_stream_object(event_type, item)) - for event_type, item in response) + return ( # type: ignore[return-value] + (event_type, cls.convert_stream_object(event_type, item)) + for event_type, item in response + ) else: return Run(**response) @classmethod def convert_stream_object(cls, event, item): event_object_map = { - 'thread.created': Thread, - 'thread.run.created': Run, - 'thread.run.queued': Run, - 'thread.run.in_progress': Run, - 'thread.run.requires_action': Run, - 'thread.run.completed': Run, - 'thread.run.failed': Run, - 'thread.run.cancelled': Run, - 'thread.run.expired': Run, - 'thread.run.step.created': RunStep, - 'thread.run.step.in_progress': RunStep, - 'thread.run.step.delta': RunStepDelta, - 'thread.run.step.completed': RunStep, - 'thread.run.step.failed': RunStep, - 'thread.run.step.cancelled': RunStep, - 'thread.run.step.expired': RunStep, - 'thread.message.created': ThreadMessage, - 'thread.message.in_progress': ThreadMessage, - 'thread.message.delta': ThreadMessageDelta, - 'thread.message.completed': ThreadMessage, - 'thread.message.incomplete': ThreadMessage, - 'error': AssistantError, + "thread.created": Thread, + "thread.run.created": Run, + "thread.run.queued": Run, + "thread.run.in_progress": Run, + "thread.run.requires_action": Run, + "thread.run.completed": Run, + "thread.run.failed": Run, + "thread.run.cancelled": Run, + "thread.run.expired": Run, + "thread.run.step.created": RunStep, + "thread.run.step.in_progress": RunStep, + "thread.run.step.delta": RunStepDelta, + 
"thread.run.step.completed": RunStep, + "thread.run.step.failed": RunStep, + "thread.run.step.cancelled": RunStep, + "thread.run.step.expired": RunStep, + "thread.message.created": ThreadMessage, + "thread.message.in_progress": ThreadMessage, + "thread.message.delta": ThreadMessageDelta, + "thread.message.completed": ThreadMessage, + "thread.message.incomplete": ThreadMessage, + "error": AssistantError, } - if (event in event_object_map): + if event in event_object_map: return event_object_map[event](**item) else: return item @classmethod - def call(cls, - thread_id: str, - *, - assistant_id: str, - model: Optional[str] = None, - instructions: Optional[str] = None, - additional_instructions: Optional[str] = None, - tools: Optional[List[Dict]] = None, - stream: Optional[bool] = False, - metadata: Optional[Dict] = None, - workspace: str = None, - extra_body: Optional[Dict] = None, - api_key: str = None, - top_p: Optional[float] = None, - top_k: Optional[int] = None, - temperature: Optional[float] = None, - max_tokens: Optional[int] = None, - **kwargs) -> Run: + def call( # type: ignore[override] + cls, + thread_id: str, + *, + assistant_id: str, + model: Optional[str] = None, + instructions: Optional[str] = None, + additional_instructions: Optional[str] = None, + tools: Optional[List[Dict]] = None, + stream: Optional[bool] = False, + metadata: Optional[Dict] = None, + workspace: str = None, + extra_body: Optional[Dict] = None, + api_key: str = None, + top_p: Optional[float] = None, + top_k: Optional[int] = None, + temperature: Optional[float] = None, + max_tokens: Optional[int] = None, + **kwargs, + ) -> Run: """Create a run. Args: thread_id (str): The thread to run. assistant_id (str): The assistant id to run. model (str): The model to use. - instructions (str, optional): The system instructions this assistant uses. Defaults to None. 
- additional_instructions (Optional[str], optional): Appends additional + instructions (str, optional): + The system instructions this assistant uses. Defaults to None. + additional_instructions (Optional[str], optional): + Appends additional instructions at the end of the instructions for the run. This is useful for modifying the behavior on a per-run basis without overriding other instructions.. Defaults to None. - tools (Optional[str], optional): List of tools to use.. Defaults to []. - metadata (Dict, optional): Custom key-value pairs associate with run. Defaults to None. + tools (Optional[str], optional): List of tools to use.. Defaults to []. # noqa: E501 + metadata (Dict, optional): + Custom key-value pairs associate with run. Defaults to None. workspace (str): The DashScope workspace id. - api_key (str, optional): The DashScope workspace id. Defaults to None. + api_key (str, optional): + The DashScope workspace id. Defaults to None. Raises: InputRequired: thread and assistant is required. @@ -220,69 +261,79 @@ def call(cls, Returns: Run: The `Run` object. 
""" - return cls.create(thread_id, - assistant_id=assistant_id, - model=model, - instructions=instructions, - additional_instructions=additional_instructions, - tools=tools, - stream=stream, - metadata=metadata, - workspace=workspace, - extra_body=extra_body, - api_key=api_key, - top_p=top_p, - top_k=top_k, - temperature=temperature, - max_tokens=max_tokens, - **kwargs) + return cls.create( + thread_id, + assistant_id=assistant_id, + model=model, + instructions=instructions, + additional_instructions=additional_instructions, + tools=tools, + stream=stream, + metadata=metadata, + workspace=workspace, + extra_body=extra_body, + api_key=api_key, + top_p=top_p, + top_k=top_k, + temperature=temperature, + max_tokens=max_tokens, + **kwargs, + ) @classmethod - def list(cls, - thread_id: str, - *, - limit: int = None, - order: str = None, - after: str = None, - before: str = None, - workspace: str = None, - api_key: str = None, - **kwargs) -> RunList: + def list( # type: ignore[override] + cls, + thread_id: str, + *, + limit: int = None, + order: str = None, + after: str = None, + before: str = None, + workspace: str = None, + api_key: str = None, + **kwargs, + ) -> RunList: """List `Run`. Args: thread_id (str): The thread id. - limit (int, optional): How many assistant to retrieve. Defaults to None. + limit (int, optional): + How many assistant to retrieve. Defaults to None. order (str, optional): Sort order by created_at. Defaults to None. after (str, optional): Assistant id after. Defaults to None. before (str, optional): Assistant id before. Defaults to None. - workspace (str, optional): The DashScope workspace id. Defaults to None. + workspace (str, optional): + The DashScope workspace id. Defaults to None. api_key (str, optional): Your DashScope api key. Defaults to None. Returns: RunList: The list of runs. 
""" if not thread_id: - raise InputRequired('thread_id is required!') - response = super().list(limit=limit, - order=order, - after=after, - before=before, - path=f'threads/{thread_id}/runs', - workspace=workspace, - api_key=api_key, - flattened_output=True, - **kwargs) + raise InputRequired("thread_id is required!") + response = super().list( + limit=limit, + order=order, + after=after, + before=before, + path=f"threads/{thread_id}/runs", + workspace=workspace, + api_key=api_key, + flattened_output=True, + **kwargs, + ) return RunList(**response) @classmethod - def retrieve(cls, - run_id: str, - *, - thread_id: str, - workspace: str = None, - api_key: str = None, - **kwargs) -> Run: + def retrieve( + cls, + run_id: str, + *, + thread_id: str, + workspace: str = None, + api_key: str = None, + **kwargs, + ) -> Run: """Retrieve the `Run`. Args: @@ -295,23 +346,27 @@ def retrieve(cls, Run: The `Run` object. """ if not thread_id or not run_id: - raise InputRequired('thread_id and run_id are required!') - response = super().get(run_id, - path=f'threads/{thread_id}/runs/{run_id}', - workspace=workspace, - api_key=api_key, - flattened_output=True, - **kwargs) + raise InputRequired("thread_id and run_id are required!") + response = super().get( + run_id, + path=f"threads/{thread_id}/runs/{run_id}", + workspace=workspace, + api_key=api_key, + flattened_output=True, + **kwargs, + ) return Run(**response) @classmethod - def get(cls, - run_id: str, - *, - thread_id: str, - workspace: str = None, - api_key: str = None, - **kwargs) -> Run: + def get( # type: ignore[override] + cls, + run_id: str, + *, + thread_id: str, + workspace: str = None, + api_key: str = None, + **kwargs, + ) -> Run: """Retrieve the `Run`. Args: @@ -323,23 +378,27 @@ def get(cls, Returns: Run: The `Run` object. 
""" - return cls.retrieve(run_id, - thread_id=thread_id, - workspace=workspace, - api_key=api_key, - **kwargs) + return cls.retrieve( + run_id, + thread_id=thread_id, + workspace=workspace, + api_key=api_key, + **kwargs, + ) @classmethod - def submit_tool_outputs(cls, - run_id: str, - *, - thread_id: str, - tool_outputs: List[Dict], - stream: Optional[bool] = False, - workspace: str = None, - extra_body: Optional[Dict] = None, - api_key: str = None, - **kwargs) -> Run: + def submit_tool_outputs( + cls, + run_id: str, + *, + thread_id: str, + tool_outputs: List[Dict], + stream: Optional[bool] = False, + workspace: str = None, + extra_body: Optional[Dict] = None, + api_key: str = None, + **kwargs, + ) -> Run: """_summary_ Args: @@ -357,38 +416,43 @@ def submit_tool_outputs(cls, Run: The 'Run`. """ if not tool_outputs: - raise InputRequired('tool_outputs is required!') + raise InputRequired("tool_outputs is required!") if not thread_id or not run_id: - raise InputRequired('thread_id and run_id are required!') + raise InputRequired("thread_id and run_id are required!") - data = {'tool_outputs': tool_outputs} - data['stream'] = stream + data = {"tool_outputs": tool_outputs} + data["stream"] = stream if extra_body is not None and extra_body: data = {**data, **extra_body} response = super().call( data, - path=f'threads/{thread_id}/runs/{run_id}/submit_tool_outputs', + path=f"threads/{thread_id}/runs/{run_id}/submit_tool_outputs", workspace=workspace, api_key=api_key, - stream=stream, + stream=stream, # type: ignore[arg-type] flattened_output=True, - **kwargs) + **kwargs, + ) if stream: - return ((event_type, cls.convert_stream_object(event_type, item)) - for event_type, item in response) + return ( # type: ignore[return-value] + (event_type, cls.convert_stream_object(event_type, item)) + for event_type, item in response + ) else: return Run(**response) @classmethod - def wait(cls, - run_id: str, - *, - thread_id: str, - timeout_seconds: float = float('inf'), - workspace: 
str = None, - api_key: str = None, - **kwargs) -> Run: + def wait( # pylint: disable=unused-argument + cls, + run_id: str, + *, + thread_id: str, + timeout_seconds: float = float("inf"), + workspace: str = None, + api_key: str = None, + **kwargs, + ) -> Run: """Wait for run to complete. Args: @@ -402,53 +466,66 @@ def wait(cls, Run: The run final status. """ if not run_id or not thread_id: - raise InputRequired('run_id and thread_id are required!') + raise InputRequired("run_id and thread_id are required!") start_time = time.perf_counter() while True: - run = cls.get(run_id, - thread_id=thread_id, - workspace=workspace, - api_key=api_key) + run = cls.get( + run_id, + thread_id=thread_id, + workspace=workspace, + api_key=api_key, + ) if run.status_code == HTTPStatus.OK: - if hasattr(run, 'status'): + if hasattr(run, "status"): if run.status in [ - 'cancelled', 'failed', 'completed', 'expired', - 'requires_action' + "cancelled", + "failed", + "completed", + "expired", + "requires_action", ]: break - else: - time_eclipsed = time.perf_counter() - start_time - if time_eclipsed > timeout_seconds: - raise TimeoutException('Wait run complete timeout') - time.sleep(1) + time_eclipsed = ( + time.perf_counter() - start_time + ) # pylint: disable=no-else-break + if time_eclipsed > timeout_seconds: + raise TimeoutException("Wait run complete timeout") + time.sleep(1) else: - logger.error('run has no status') + logger.error("run has no status") break else: logger.error( - 'Get run thread_id: %s, run_id: %s failed, message: %s' % - (thread_id, run_id, run.message)) + "Get run thread_id: %s, run_id: %s failed, message: %s", + thread_id, + run_id, + run.message, + ) break return run @classmethod - def update(cls, - run_id: str, - *, - thread_id: str, - metadata: Optional[Dict] = None, - workspace: str = None, - api_key: str = None, - **kwargs) -> Run: + def update( # type: ignore[override] + cls, + run_id: str, + *, + thread_id: str, + metadata: Optional[Dict] = None, + workspace: 
str = None, + api_key: str = None, + **kwargs, + ) -> Run: """Create a run. Args: thread_id (str): The thread of the run id to be updated. run_id (str): The run id to update. model (str): The model to use. - metadata (Dict, optional): Custom key-value pairs associate with run. Defaults to None. + metadata (Dict, optional): + Custom key-value pairs associate with run. Defaults to None. workspace (str): The DashScope workspace id. - api_key (str, optional): The DashScope workspace id. Defaults to None. + api_key (str, optional): + The DashScope workspace id. Defaults to None. Raises: InputRequired: thread id and run is required. @@ -457,26 +534,29 @@ def update(cls, Run: The `Run` object. """ if not thread_id or not run_id: - raise InputRequired('thread id and run id are required!') - response = super().update(run_id, - json={'metadata': metadata}, - path='threads/%s/runs/%s' % - (thread_id, run_id), - api_key=api_key, - workspace=workspace, - flattened_output=True, - method='post', - **kwargs) + raise InputRequired("thread id and run id are required!") + response = super().update( + run_id, + json={"metadata": metadata}, + path=f"threads/{thread_id}/runs/{run_id}", + api_key=api_key, + workspace=workspace, + flattened_output=True, + method="post", + **kwargs, + ) return Run(**response) @classmethod - def cancel(cls, - run_id: str, - *, - thread_id: str, - workspace: str = None, - api_key: str = None, - **kwargs) -> Run: + def cancel( # type: ignore[override] + cls, + run_id: str, + *, + thread_id: str, + workspace: str = None, + api_key: str = None, + **kwargs, + ) -> Run: """Cancel the `Run`. Args: @@ -489,13 +569,14 @@ def cancel(cls, Run: The `Run` object. 
""" if not thread_id or not run_id: - raise InputRequired('thread id and run id are required!') - response = super().cancel(run_id, - path='threads/%s/runs/%s/cancel' % - (thread_id, run_id), - api_key=api_key, - workspace=workspace, - flattened_output=True, - **kwargs) + raise InputRequired("thread id and run id are required!") + response = super().cancel( + run_id, + path=f"threads/{thread_id}/runs/{run_id}/cancel", + api_key=api_key, + workspace=workspace, + flattened_output=True, + **kwargs, + ) return Run(**response) diff --git a/dashscope/threads/runs/steps.py b/dashscope/threads/runs/steps.py index 240e4e0..5299012 100644 --- a/dashscope/threads/runs/steps.py +++ b/dashscope/threads/runs/steps.py @@ -1,65 +1,73 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. from dashscope.client.base_api import GetStatusMixin, ListObjectMixin from dashscope.common.error import InputRequired from dashscope.threads.thread_types import RunStep, RunStepList -__all__ = ['Steps'] +__all__ = ["Steps"] class Steps(ListObjectMixin, GetStatusMixin): - SUB_PATH = 'RUNS' # useless + SUB_PATH = "RUNS" # useless @classmethod - def list(cls, - run_id: str, - *, - thread_id: str, - limit: int = None, - order: str = None, - after: str = None, - before: str = None, - workspace: str = None, - api_key: str = None, - **kwargs) -> RunStepList: + def list( # type: ignore[override] + cls, + run_id: str, + *, + thread_id: str, + limit: int = None, + order: str = None, + after: str = None, + before: str = None, + workspace: str = None, + api_key: str = None, + **kwargs, + ) -> RunStepList: """List `RunStep` of `Run`. Args: thread_id (str): The thread id. run_id (str): The run id. - limit (int, optional): How many assistant to retrieve. Defaults to None. + limit (int, optional): + How many assistant to retrieve. Defaults to None. order (str, optional): Sort order by created_at. Defaults to None. after (str, optional): Assistant id after. Defaults to None. 
before (str, optional): Assistant id before. Defaults to None. - workspace (str, optional): The DashScope workspace id. Defaults to None. + workspace (str, optional): + The DashScope workspace id. Defaults to None. api_key (str, optional): Your DashScope api key. Defaults to None. Returns: RunList: The list of runs. """ if not run_id: - raise InputRequired('run_id is required!') + raise InputRequired("run_id is required!") response = super().list( limit=limit, order=order, after=after, before=before, - path=f'threads/{thread_id}/runs/{run_id}/steps', + path=f"threads/{thread_id}/runs/{run_id}/steps", workspace=workspace, api_key=api_key, flattened_output=True, - **kwargs) + **kwargs, + ) return RunStepList(**response) @classmethod - def retrieve(cls, - step_id: str, - *, - thread_id: str, - run_id: str, - workspace: str = None, - api_key: str = None, - **kwargs) -> RunStep: + def retrieve( + cls, + step_id: str, + *, + thread_id: str, + run_id: str, + workspace: str = None, + api_key: str = None, + **kwargs, + ) -> RunStep: """Retrieve the `RunStep`. Args: @@ -73,25 +81,28 @@ def retrieve(cls, RunStep: The `RunStep` object. """ if not thread_id or not run_id or not step_id: - raise InputRequired('thread_id, run_id and step_id are required!') + raise InputRequired("thread_id, run_id and step_id are required!") response = super().get( run_id, - path=f'threads/{thread_id}/runs/{run_id}/steps/{step_id}', + path=f"threads/{thread_id}/runs/{run_id}/steps/{step_id}", workspace=workspace, api_key=api_key, flattened_output=True, - **kwargs) + **kwargs, + ) return RunStep(**response) @classmethod - def get(cls, - step_id: str, - *, - thread_id: str, - run_id: str, - workspace: str = None, - api_key: str = None, - **kwargs) -> RunStep: + def get( # type: ignore[override] + cls, + step_id: str, + *, + thread_id: str, + run_id: str, + workspace: str = None, + api_key: str = None, + **kwargs, + ) -> RunStep: """Retrieve the `RunStep`. 
Args: @@ -104,9 +115,12 @@ def get(cls, Returns: RunStep: The `RunStep` object. """ - return cls.retrieve(thread_id, - run_id, - step_id, - workspace=workspace, - api_key=api_key, - **kwargs) + # pylint: disable=too-many-function-args + return cls.retrieve( # type: ignore[misc] + thread_id, + run_id, + step_id, + workspace=workspace, + api_key=api_key, + **kwargs, + ) diff --git a/dashscope/threads/thread_types.py b/dashscope/threads/thread_types.py index 2fb467f..b47ff07 100644 --- a/dashscope/threads/thread_types.py +++ b/dashscope/threads/thread_types.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. # adapter from openai sdk @@ -6,7 +7,8 @@ from typing import Dict, List, Literal, Optional, Union from dashscope.assistants.assistant_types import ( - Tool, convert_tools_dict_to_objects) + Tool, convert_tools_dict_to_objects, +) from dashscope.common.base_type import BaseList, BaseObjectMixin __all__ = [ @@ -18,7 +20,7 @@ 'MessageCreationStepDetails', 'CodeInterpreterOutputLogs', 'CodeInterpreterOutputImageImage', 'CodeInterpreterOutputImage', 'CodeInterpreter', 'CodeToolCall', 'RetrievalToolCall', 'FunctionToolCall', - 'ToolCallsStepDetails', 'RunStep', 'RunStepList' + 'ToolCallsStepDetails', 'RunStep', 'RunStepList', ] @@ -29,7 +31,7 @@ class MessageFile(BaseObjectMixin): created_at: int object: str - def __init__(self, **kwargs): + def __init__(self, **kwargs): # pylint: disable=useless-parent-delegation super().__init__(**kwargs) @@ -37,7 +39,7 @@ def __init__(self, **kwargs): class MessageFileList(BaseList): data: List[MessageFile] - def __init__(self, **kwargs): + def __init__(self, **kwargs): # pylint: disable=useless-parent-delegation super().__init__(**kwargs) @@ -58,7 +60,7 @@ class Usage(BaseObjectMixin): output_tokens: int """Output tokens used (completion).""" - def __init__(self, **kwargs): + def __init__(self, **kwargs): # pylint: disable=useless-parent-delegation super().__init__(**kwargs) @@ -66,14 +68,16 
@@ def __init__(self, **kwargs): class ImageFile(BaseObjectMixin): file_id: str - def __init__(self, file_id, **kwargs): + def __init__( # pylint: disable=unused-argument + self, file_id, **kwargs, + ): super().__init__(**kwargs) @dataclass(init=False) class MessageContentImageFile(BaseObjectMixin): type: str = 'image_file' - image_file: ImageFile + image_file: ImageFile # type: ignore[misc] def __init__(self, **kwargs): self.image_file = ImageFile(**kwargs.pop(self.type, {})) @@ -144,7 +148,7 @@ class ThreadMessageDelta(BaseObjectMixin): status_code: int id: str object: str = 'thread.message.delta' - delta: ThreadMessageDeltaContent + delta: ThreadMessageDeltaContent # type: ignore[misc] def __init__(self, **kwargs): content = kwargs.pop('delta', None) @@ -176,7 +180,8 @@ def __init__(self, **kwargs): for content in input_content: if 'type' in content: content_type = MESSAGE_SUPPORT_CONTENT.get( - content['type'], None) + content['type'], None, + ) if content_type: content_list.append(content_type(**content)) else: @@ -194,17 +199,22 @@ def __init__(self, **kwargs): class ThreadMessageList(BaseList): data: List[ThreadMessage] - def __init__(self, - has_more: bool = None, - last_id: Optional[str] = None, - first_id: Optional[str] = None, - data: List[ThreadMessage] = [], - **kwargs): - super().__init__(has_more=has_more, - last_id=last_id, - first_id=first_id, - data=data, - **kwargs) + # pylint: disable=dangerous-default-value + def __init__( + self, + has_more: bool = None, + last_id: Optional[str] = None, + first_id: Optional[str] = None, + data: List[ThreadMessage] = [], + **kwargs, + ): + super().__init__( + has_more=has_more, + last_id=last_id, + first_id=first_id, + data=data, + **kwargs, + ) @dataclass(init=False) @@ -215,6 +225,7 @@ class Thread(BaseObjectMixin): metadata: Optional[object] = None object: str = 'thread' + # pylint: disable=useless-parent-delegation def __init__(self, **kwargs): super().__init__(**kwargs) @@ -225,6 +236,7 @@ class 
Function(BaseObjectMixin): name: str output: Optional[str] = None + # pylint: disable=useless-parent-delegation def __init__(self, **kwargs): super().__init__(**kwargs) @@ -265,7 +277,8 @@ class RequiredAction(BaseObjectMixin): def __init__(self, **kwargs): self.submit_tool_outputs = RequiredActionSubmitToolOutputs( - **kwargs.pop('submit_tool_outputs', {})) + **kwargs.pop('submit_tool_outputs', {}), + ) super().__init__(**kwargs) @@ -274,6 +287,7 @@ class LastError(BaseObjectMixin): code: Literal['server_error', 'rate_limit_exceeded'] message: str + # pylint: disable=useless-parent-delegation def __init__(self, **kwargs): super().__init__(**kwargs) @@ -302,12 +316,14 @@ class Run(BaseObjectMixin): started_at: Optional[int] = None - status: Literal['queued', 'in_progress', 'requires_action', 'cancelling', - 'cancelled', 'failed', 'completed', 'expired'] + status: Literal[ # type: ignore[misc] + 'queued', 'in_progress', 'requires_action', 'cancelling', + 'cancelled', 'failed', 'completed', 'expired', + ] - thread_id: str + thread_id: str # type: ignore[misc] - tools: List[Tool] + tools: List[Tool] # type: ignore[misc] top_p: Optional[float] = None top_k: Optional[int] = None @@ -331,23 +347,30 @@ def __init__(self, **kwargs): class RunList(BaseObjectMixin): data: List[Run] - def __init__(self, - has_more: bool = None, - last_id: Optional[str] = None, - first_id: Optional[str] = None, - data: List[Run] = [], - **kwargs): - super().__init__(has_more=has_more, - last_id=last_id, - first_id=first_id, - data=data, - **kwargs) + # pylint: disable=dangerous-default-value + def __init__( + self, + has_more: bool = None, + last_id: Optional[str] = None, + first_id: Optional[str] = None, + data: List[Run] = [], + **kwargs, + ): + super().__init__( + has_more=has_more, + last_id=last_id, + first_id=first_id, + data=data, + **kwargs, + ) @dataclass(init=False) class MessageCreation(BaseObjectMixin): message_id: str """The ID of the message that was created by this run step.""" 
+ + # pylint: disable=useless-parent-delegation def __init__(self, **kwargs): super().__init__(**kwargs) @@ -358,6 +381,8 @@ class MessageCreationStepDetails(BaseObjectMixin): type: Literal['message_creation'] """Always `message_creation`.""" + + # pylint: disable=useless-parent-delegation def __init__(self, **kwargs): super().__init__(**kwargs) @@ -369,6 +394,8 @@ class CodeInterpreterOutputLogs(BaseObjectMixin): type: Literal['logs'] """Always `logs`.""" + + # pylint: disable=useless-parent-delegation def __init__(self, **kwargs): super().__init__(**kwargs) @@ -380,6 +407,8 @@ class CodeInterpreterOutputImageImage(BaseObjectMixin): The [file](https://platform.openai.com/docs/api-reference/files) ID of the image. """ + + # pylint: disable=useless-parent-delegation def __init__(self, **kwargs): super().__init__(**kwargs) @@ -395,8 +424,10 @@ def __init__(self, **kwargs): super().__init__(**kwargs) -CodeInterpreterOutput = Union[CodeInterpreterOutputLogs, - CodeInterpreterOutputImage] +CodeInterpreterOutput = Union[ + CodeInterpreterOutputLogs, + CodeInterpreterOutputImage, +] @dataclass(init=False) @@ -407,7 +438,7 @@ class CodeInterpreter(BaseObjectMixin): outputs: List[CodeInterpreterOutput] """The outputs from the Code Interpreter tool call. - Code Interpreter can output one or more items, including text (`logs`) or images + Code Interpreter can output one or more items, including text (`logs`) or images # noqa: E501 (`image`). Each of these are represented by a different object type. """ def __init__(self, **kwargs): @@ -433,7 +464,8 @@ class CodeToolCall(BaseObjectMixin): """ def __init__(self, **kwargs): self.code_interpreter = CodeInterpreter( - **kwargs.pop('code_interpreter', {})) + **kwargs.pop('code_interpreter', {}), + ) super().__init__(**kwargs) @@ -450,6 +482,8 @@ class RetrievalToolCall(BaseObjectMixin): This is always going to be `quark_search` for this type of tool call. 
""" + + # pylint: disable=useless-parent-delegation def __init__(self, **kwargs): super().__init__(**kwargs) @@ -500,7 +534,7 @@ class ToolCallsStepDetails(BaseObjectMixin): tool_calls: List[ToolCall] """An array of tool calls the run step was involved in. - These can be associated with one of three types of tools: `code_interpreter`, + These can be associated with one of three types of tools: `code_interpreter`, # noqa: E501 `retrieval`, or `function`. """ @@ -508,7 +542,8 @@ class ToolCallsStepDetails(BaseObjectMixin): """Always `tool_calls`.""" def __init__(self, **kwargs): self.tool_calls = convert_tool_calls_dict_to_object( - kwargs.pop('tool_calls', [])) + kwargs.pop('tool_calls', []), + ) super().__init__(**kwargs) @@ -533,7 +568,8 @@ class RunStepDeltaContent(BaseObjectMixin): def __init__(self, **kwargs): self.step_details = convert_step_details_dict_to_objects( - kwargs.pop('step_details', {})) + kwargs.pop('step_details', {}), + ) super().__init__(**kwargs) @@ -541,7 +577,7 @@ def __init__(self, **kwargs): class RunStepDelta(BaseObjectMixin): id: str object: str = 'thread.run.step.delta' - delta: RunStepDeltaContent + delta: RunStepDeltaContent # type: ignore[misc] def __init__(self, **kwargs): delta = kwargs.pop('delta', None) @@ -555,10 +591,10 @@ def __init__(self, **kwargs): @dataclass(init=False) class RunStep(BaseObjectMixin): status_code: int = None - id: str - """The identifier of the run step, which can be referenced in API endpoints.""" + id: str # type: ignore[misc] + """The identifier of the run step, which can be referenced in API endpoints.""" # noqa: E501 - assistant_id: str + assistant_id: str # type: ignore[misc] """ The ID of the [assistant](https://platform.openai.com/docs/api-reference/assistants) @@ -571,7 +607,7 @@ class RunStep(BaseObjectMixin): completed_at: Optional[int] = None """The Unix timestamp (in seconds) for when the run step completed.""" - created_at: int + created_at: int # type: ignore[misc] """The Unix timestamp (in 
seconds) for when the run step was created.""" expired_at: Optional[int] = None @@ -593,52 +629,54 @@ class RunStep(BaseObjectMixin): """Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a - structured format. Keys can be a maximum of 64 characters long and values can be + structured format. Keys can be a maximum of 64 characters long and values can be # noqa: E501 a maxium of 512 characters long. """ - object: Literal['thread.run.step'] + object: Literal['thread.run.step'] # type: ignore[misc] """The object type, which is always `thread.run.step`.""" - run_id: str + run_id: str # type: ignore[misc] """ - The ID of the [run](https://platform.openai.com/docs/api-reference/runs) that + The ID of the [run](https://platform.openai.com/docs/api-reference/runs) that # noqa: E501 this run step is a part of. """ - status: Literal['in_progress', 'cancelled', 'failed', 'completed', - 'expired'] + status: Literal[ # type: ignore[misc] + 'in_progress', 'cancelled', 'failed', 'completed', + 'expired', + ] """ The status of the run step, which can be either `in_progress`, `cancelled`, `failed`, `completed`, or `expired`. """ - step_details: StepDetails + step_details: StepDetails # type: ignore[misc] """The details of the run step.""" - thread_id: str + thread_id: str # type: ignore[misc] """ - The ID of the [thread](https://platform.openai.com/docs/api-reference/threads) + The ID of the [thread](https://platform.openai.com/docs/api-reference/threads) # noqa: E501 that was run. 
""" - type: Literal['message_creation', 'tool_calls'] - """The type of run step, which can be either `message_creation` or `tool_calls`.""" + type: Literal['message_creation', 'tool_calls'] # type: ignore[misc] + """The type of run step, which can be either `message_creation` or `tool_calls`.""" # noqa: E501 # pylint: disable=line-too-long usage: Optional[Usage] = None def __init__(self, **kwargs): self.step_details = convert_step_details_dict_to_objects( - kwargs.pop('step_details', {})) - if 'usage' in kwargs and kwargs['usage'] is not None and kwargs['usage']: + kwargs.pop('step_details', {}), + ) + if 'usage' in kwargs and kwargs['usage'] is not None and kwargs['usage']: # noqa: E501 self.usage = Usage(**kwargs.pop('usage', {})) else: self.usage = None last_error = kwargs.pop('last_error', None) if last_error: self.last_error = LastError(**last_error) - else: - last_error = last_error + super().__init__(**kwargs) @@ -646,20 +684,25 @@ def __init__(self, **kwargs): class RunStepList(BaseList): data: List[RunStep] - def __init__(self, - has_more: bool = None, - last_id: Optional[str] = None, - first_id: Optional[str] = None, - data: List[RunStep] = [], - **kwargs): + # pylint: disable=dangerous-default-value + def __init__( + self, + has_more: bool = None, + last_id: Optional[str] = None, + first_id: Optional[str] = None, + data: List[RunStep] = [], + **kwargs, + ): if data: steps = [] for step in data: - steps.append(RunStep(**step)) + steps.append(RunStep(**step)) # type: ignore[arg-type] self.data = steps else: self.data = [] - super().__init__(has_more=has_more, - last_id=last_id, - first_id=first_id, - **kwargs) + super().__init__( + has_more=has_more, + last_id=last_id, + first_id=first_id, + **kwargs, + ) diff --git a/dashscope/threads/threads.py b/dashscope/threads/threads.py index 5430573..066232d 100644 --- a/dashscope/threads/threads.py +++ b/dashscope/threads/threads.py @@ -1,58 +1,77 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. 
and its affiliates. from typing import Dict, List, Optional from dashscope.assistants.assistant_types import DeleteResponse -from dashscope.client.base_api import (CreateMixin, DeleteMixin, - GetStatusMixin, UpdateMixin) +from dashscope.client.base_api import ( + CreateMixin, + DeleteMixin, + GetStatusMixin, + UpdateMixin, +) from dashscope.common.error import InputRequired from dashscope.threads.thread_types import Run, Thread -__all__ = ['Threads'] +__all__ = ["Threads"] class Threads(CreateMixin, DeleteMixin, GetStatusMixin, UpdateMixin): - SUB_PATH = 'threads' + SUB_PATH = "threads" @classmethod - def call(cls, - *, - messages: List[Dict] = None, - metadata: Dict = None, - workspace: str = None, - api_key: str = None, - **kwargs) -> Thread: + def call( # type: ignore[override] + cls, + *, + messages: List[Dict] = None, + metadata: Dict = None, + workspace: str = None, + api_key: str = None, + **kwargs, + ) -> Thread: """Create a thread. Args: - messages (List[Dict], optional): List of messages to start thread. Defaults to None. - metadata (Dict, optional): The key-value information associate with thread. Defaults to None. - workspace (str, optional): The DashScope workspace id. Defaults to None. + messages (List[Dict], optional): + List of messages to start thread. Defaults to None. + metadata (Dict, optional): + The key-value information associate with thread. Defaults to + None. + workspace (str, optional): + The DashScope workspace id. Defaults to None. api_key (str, optional): Your DashScope api key. Defaults to None. Returns: Thread: The thread object. 
""" - return cls.create(messages=messages, - metadata=metadata, - workspace=workspace, - api_key=api_key, - **kwargs) + return cls.create( + messages=messages, + metadata=metadata, + workspace=workspace, + api_key=api_key, + **kwargs, + ) @classmethod - def create(cls, - *, - messages: List[Dict] = None, - metadata: Dict = None, - workspace: str = None, - api_key: str = None, - **kwargs) -> Thread: + def create( + cls, + *, + messages: List[Dict] = None, + metadata: Dict = None, + workspace: str = None, + api_key: str = None, + **kwargs, + ) -> Thread: """Create a thread. Args: - messages (List[Dict], optional): List of messages to start thread. Defaults to None. - metadata (Dict, optional): The key-value information associate with thread. Defaults to None. - workspace (str, optional): The DashScope workspace id. Defaults to None. + messages (List[Dict], optional): + List of messages to start thread. Defaults to None. + metadata (Dict, optional): + The key-value information associate with thread. Defaults to + None. + workspace (str, optional): + The DashScope workspace id. Defaults to None. api_key (str, optional): Your DashScope api key. Defaults to None. Returns: @@ -60,153 +79,180 @@ def create(cls, """ data = {} if messages: - data['messages'] = messages + data["messages"] = messages if metadata: - data['metadata'] = metadata - response = super().call(data=data if data else '', - api_key=api_key, - flattened_output=True, - workspace=workspace, - **kwargs) + data["metadata"] = metadata + response = super().call( + data=data if data else "", + api_key=api_key, + flattened_output=True, + workspace=workspace, + **kwargs, + ) return Thread(**response) @classmethod - def get(cls, - thread_id: str, - *, - workspace: str = None, - api_key: str = None, - **kwargs) -> Thread: + def get( # type: ignore[override] + cls, + thread_id: str, + *, + workspace: str = None, + api_key: str = None, + **kwargs, + ) -> Thread: """Retrieve the thread. 
Args: thread_id (str): The target thread. - workspace (str, optional): The DashScope workspace id. Defaults to None. + workspace (str, optional): + The DashScope workspace id. Defaults to None. api_key (str, optional): Your DashScope api key. Defaults to None. Returns: Thread: The `Thread` information. """ - return cls.retrieve(thread_id, - workspace=workspace, - api_key=api_key, - **kwargs) + return cls.retrieve( + thread_id, + workspace=workspace, + api_key=api_key, + **kwargs, + ) @classmethod - def retrieve(cls, - thread_id: str, - *, - workspace: str = None, - api_key: str = None, - **kwargs) -> Thread: + def retrieve( + cls, + thread_id: str, + *, + workspace: str = None, + api_key: str = None, + **kwargs, + ) -> Thread: """Retrieve the thread. Args: thread_id (str): The target thread. - workspace (str, optional): The DashScope workspace id. Defaults to None. + workspace (str, optional): + The DashScope workspace id. Defaults to None. api_key (str, optional): Your DashScope api key. Defaults to None. Returns: Thread: The `Thread` information. """ if not thread_id: - raise InputRequired('thread_id is required!') - response = super().get(thread_id, - api_key=api_key, - flattened_output=True, - workspace=workspace, - **kwargs) + raise InputRequired("thread_id is required!") + response = super().get( + thread_id, + api_key=api_key, + flattened_output=True, + workspace=workspace, + **kwargs, + ) return Thread(**response) @classmethod - def update(cls, - thread_id: str, - *, - metadata: Dict = None, - workspace: str = None, - api_key: str = None, - **kwargs) -> Thread: + def update( # type: ignore[override] + cls, + thread_id: str, + *, + metadata: Dict = None, + workspace: str = None, + api_key: str = None, + **kwargs, + ) -> Thread: """Update thread information. Args: thread_id (str): The thread id. - metadata (Dict, optional): The thread key-value information. Defaults to None. - workspace (str, optional): The DashScope workspace id. Defaults to None. 
+ metadata (Dict, optional): + The thread key-value information. Defaults to None. + workspace (str, optional): + The DashScope workspace id. Defaults to None. api_key (str, optional): Your DashScope api key. Defaults to None. Returns: Thread: The `Thread` information. """ if not thread_id: - raise InputRequired('thread_id is required!') - response = super().update(thread_id, - json={'metadata': metadata}, - api_key=api_key, - workspace=workspace, - flattened_output=True, - method='post', - **kwargs) + raise InputRequired("thread_id is required!") + response = super().update( + thread_id, + json={"metadata": metadata}, + api_key=api_key, + workspace=workspace, + flattened_output=True, + method="post", + **kwargs, + ) return Thread(**response) @classmethod - def delete(cls, - thread_id, - *, - workspace: str = None, - api_key: str = None, - **kwargs) -> DeleteResponse: + def delete( # type: ignore[override] + cls, + thread_id, + *, + workspace: str = None, + api_key: str = None, + **kwargs, + ) -> DeleteResponse: """Delete thread. Args: thread_id (str): The thread id to delete. - workspace (str, optional): The DashScope workspace id. Defaults to None. + workspace (str, optional): + The DashScope workspace id. Defaults to None. api_key (str, optional): Your DashScope api key. Defaults to None. Returns: AssistantsDeleteResponse: The deleted information. 
""" if not thread_id: - raise InputRequired('thread_id is required!') - response = super().delete(thread_id, - api_key=api_key, - workspace=workspace, - flattened_output=True, - **kwargs) + raise InputRequired("thread_id is required!") + response = super().delete( + thread_id, + api_key=api_key, + workspace=workspace, + flattened_output=True, + **kwargs, + ) return DeleteResponse(**response) @classmethod - def create_and_run(cls, - *, - assistant_id: str, - thread: Optional[Dict] = None, - model: Optional[str] = None, - instructions: Optional[str] = None, - additional_instructions: Optional[str] = None, - tools: Optional[List[Dict]] = None, - metadata: Optional[Dict] = None, - workspace: str = None, - api_key: str = None, - **kwargs) -> Run: + def create_and_run( + cls, + *, + assistant_id: str, + thread: Optional[Dict] = None, + model: Optional[str] = None, + instructions: Optional[str] = None, + additional_instructions: Optional[str] = None, + tools: Optional[List[Dict]] = None, + metadata: Optional[Dict] = None, + workspace: str = None, + api_key: str = None, + **kwargs, + ) -> Run: if not assistant_id: - raise InputRequired('assistant_id is required') - data = {'assistant_id': assistant_id} + raise InputRequired("assistant_id is required") + data = {"assistant_id": assistant_id} if thread: - data['thread'] = thread + data["thread"] = thread if model: - data['model'] = model + data["model"] = model if instructions: - data['instructions'] = instructions + data["instructions"] = instructions if additional_instructions: - data['additional_instructions'] = additional_instructions + data["additional_instructions"] = additional_instructions if tools: - data['tools'] = tools + data["tools"] = tools if metadata: - data['metadata'] = metadata - - response = super().call(data=data, - path='threads/runs', - api_key=api_key, - flattened_output=True, - workspace=workspace, - **kwargs) + data["metadata"] = metadata + + response = super().call( + data=data, + 
path="threads/runs", + api_key=api_key, + flattened_output=True, + workspace=workspace, + **kwargs, + ) return Run(**response) diff --git a/dashscope/tokenizers/__init__.py b/dashscope/tokenizers/__init__.py index b63f34e..e43a8af 100644 --- a/dashscope/tokenizers/__init__.py +++ b/dashscope/tokenizers/__init__.py @@ -1,7 +1,13 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. from .tokenization import Tokenization from .tokenizer import get_tokenizer, list_tokenizers from .tokenizer_base import Tokenizer -__all__ = [Tokenization, Tokenizer, get_tokenizer, list_tokenizers] +__all__ = [ + "Tokenization", + "Tokenizer", + "get_tokenizer", + "list_tokenizers", +] diff --git a/dashscope/tokenizers/qwen_tokenizer.py b/dashscope/tokenizers/qwen_tokenizer.py index 78a1599..dcfc2a9 100644 --- a/dashscope/tokenizers/qwen_tokenizer.py +++ b/dashscope/tokenizers/qwen_tokenizer.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. import base64 @@ -6,55 +7,61 @@ from .tokenizer_base import Tokenizer -PAT_STR = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""" # noqa E501 -ENDOFTEXT = '<|endoftext|>' -IMSTART = '<|im_start|>' -IMEND = '<|im_end|>' +PAT_STR = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""" # noqa E501 # pylint: disable=line-too-long +ENDOFTEXT = "<|endoftext|>" +IMSTART = "<|im_start|>" +IMEND = "<|im_end|>" # as the default behavior is changed to allow special tokens in # regular texts, the surface forms of special tokens need to be # as different as possible to minimize the impact -EXTRAS = tuple((f'<|extra_{i}|>' for i in range(205))) -# changed to use actual index to avoid misconfiguration with vocabulary expansion +EXTRAS = tuple((f"<|extra_{i}|>" for i in range(205))) +# changed to use actual index to avoid misconfiguration with vocabulary expansion # 
noqa: E501 SPECIAL_START_ID = 151643 SPECIAL_TOKENS = tuple( enumerate( - (( - ENDOFTEXT, - IMSTART, - IMEND, - ) + EXTRAS), + ( + ( + ENDOFTEXT, + IMSTART, + IMEND, + ) + + EXTRAS + ), start=SPECIAL_START_ID, - )) + ), +) SPECIAL_TOKENS_SET = set(t for i, t in SPECIAL_TOKENS) class QwenTokenizer(Tokenizer): @staticmethod def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]: - with open(tiktoken_bpe_file, 'rb') as f: + with open(tiktoken_bpe_file, "rb") as f: contents = f.read() return { base64.b64decode(token): int(rank) - for token, rank in (line.split() for line in contents.splitlines() - if line) + for token, rank in ( + line.split() for line in contents.splitlines() if line + ) } - def __init__(self, vocab_file, errors='replace', extra_vocab_file=None): + def __init__(self, vocab_file, errors="replace", extra_vocab_file=None): self._errors = errors self._vocab_file = vocab_file self._extra_vocab_file = extra_vocab_file self._mergeable_ranks = QwenTokenizer._load_tiktoken_bpe( - vocab_file) # type: Dict[bytes, int] + vocab_file, + ) # type: Dict[bytes, int] self._special_tokens = { - token: index - for index, token in SPECIAL_TOKENS + token: index for index, token in SPECIAL_TOKENS } # try load extra vocab from file if extra_vocab_file is not None: used_ids = set(self._mergeable_ranks.values()) | set( - self._special_tokens.values()) + self._special_tokens.values(), + ) extra_mergeable_ranks = self._load_tiktoken_bpe(extra_vocab_file) for token, index in extra_mergeable_ranks.items(): if token in self._mergeable_ranks: @@ -62,22 +69,23 @@ def __init__(self, vocab_file, errors='replace', extra_vocab_file=None): if index in used_ids: continue self._mergeable_ranks[token] = index - # the index may be sparse after this, but don't worry tiktoken.Encoding will handle this + # the index may be sparse after this, but don't worry tiktoken.Encoding will handle this # noqa: E501 # pylint: disable=line-too-long import tiktoken + enc = tiktoken.Encoding( - 
'Qwen', + "Qwen", pat_str=PAT_STR, mergeable_ranks=self._mergeable_ranks, special_tokens=self._special_tokens, ) assert ( - len(self._mergeable_ranks) + - len(self._special_tokens) == enc.n_vocab - ), f'{len(self._mergeable_ranks) + len(self._special_tokens)} != {enc.n_vocab} in encoding' + len(self._mergeable_ranks) + len(self._special_tokens) + == enc.n_vocab + ), f"{len(self._mergeable_ranks) + len(self._special_tokens)} != {enc.n_vocab} in encoding" # noqa: E501 # pylint: disable=line-too-long - self.decoder = {v: k - for k, v in self._mergeable_ranks.items() - } # type: dict[int, bytes|str] + self.decoder = { + v: k for k, v in self._mergeable_ranks.items() + } # type: dict[int, bytes|str] self.decoder.update({v: k for k, v in self._special_tokens.items()}) self._tokenizer = enc # type: tiktoken.Encoding @@ -86,16 +94,18 @@ def __init__(self, vocab_file, errors='replace', extra_vocab_file=None): self.im_start_id = self._special_tokens[IMSTART] self.im_end_id = self._special_tokens[IMEND] - def encode( + def encode( # type: ignore[override] self, text: str, - allowed_special: Union[Set, str] = 'all', + allowed_special: Union[Set, str] = "all", disallowed_special: Union[Collection, str] = (), ) -> Union[List[List], List]: - text = unicodedata.normalize('NFC', text) - return self._tokenizer.encode(text, - allowed_special=allowed_special, - disallowed_special=disallowed_special) + text = unicodedata.normalize("NFC", text) + return self._tokenizer.encode( + text, + allowed_special=allowed_special, + disallowed_special=disallowed_special, + ) def decode( self, diff --git a/dashscope/tokenizers/tokenization.py b/dashscope/tokenizers/tokenization.py index dd65486..ffb27ba 100644 --- a/dashscope/tokenizers/tokenization.py +++ b/dashscope/tokenizers/tokenization.py @@ -1,45 +1,57 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
import copy from typing import Any, List -from dashscope.api_entities.dashscope_response import (DashScopeAPIResponse, - Message, Role) +from dashscope.api_entities.dashscope_response import ( + DashScopeAPIResponse, + Message, + Role, +) from dashscope.client.base_api import BaseApi -from dashscope.common.constants import (CUSTOMIZED_MODEL_ID, - DEPRECATED_MESSAGE, HISTORY, MESSAGES, - PROMPT) +from dashscope.common.constants import ( + CUSTOMIZED_MODEL_ID, + DEPRECATED_MESSAGE, + HISTORY, + MESSAGES, + PROMPT, +) from dashscope.common.error import InputRequired, ModelRequired from dashscope.common.logging import logger class Tokenization(BaseApi): - FUNCTION = 'tokenizer' + FUNCTION = "tokenizer" """API for get tokenizer result.. """ + class Models: - """List of models currently supported - """ - qwen_turbo = 'qwen-turbo' - qwen_plus = 'qwen-plus' - qwen_7b_chat = 'qwen-7b-chat' - qwen_14b_chat = 'qwen-14b-chat' - llama2_7b_chat_v2 = 'llama2-7b-chat-v2' - llama2_13b_chat_v2 = 'llama2-13b-chat-v2' - text_embedding_v2 = 'text-embedding-v2' - qwen_72b_chat = 'qwen-72b-chat' + """List of models currently supported""" + + qwen_turbo = "qwen-turbo" + qwen_plus = "qwen-plus" + qwen_7b_chat = "qwen-7b-chat" + qwen_14b_chat = "qwen-14b-chat" + llama2_7b_chat_v2 = "llama2-7b-chat-v2" + llama2_13b_chat_v2 = "llama2-13b-chat-v2" + text_embedding_v2 = "text-embedding-v2" + qwen_72b_chat = "qwen-72b-chat" @classmethod - def call(cls, - model: str, - input: Any = None, - prompt: Any = None, - history: list = None, - api_key: str = None, - messages: List[Message] = None, - workspace: str = None, - **kwargs) -> DashScopeAPIResponse: + # type: ignore[override] + def call( # pylint: disable=arguments-renamed # noqa: E501 + cls, + model: str, + input: Any = None, # pylint: disable=redefined-builtin + prompt: Any = None, + history: list = None, + api_key: str = None, + messages: List[Message] = None, + workspace: str = None, + **kwargs, + ) -> DashScopeAPIResponse: """Call 
tokenization. Args: @@ -70,34 +82,43 @@ def call(cls, Returns: DashScopeAPIResponse: The tokenizer output. """ - if (input is None or not input) and \ - (prompt is None or not prompt) and \ - (messages is None or not messages): - raise InputRequired('prompt or messages is required!') + if ( # pylint: disable=too-many-boolean-expressions + (input is None or not input) + and (prompt is None or not prompt) + and (messages is None or not messages) + ): + raise InputRequired("prompt or messages is required!") if model is None or not model: - raise ModelRequired('Model is required!') + raise ModelRequired("Model is required!") if input is None: input, parameters = cls._build_llm_parameters( - model, prompt, history, messages, **kwargs) + model, + prompt, + history, + messages, + **kwargs, + ) else: parameters = kwargs - if kwargs.pop('stream', False): # not support stream - logger.warning('streaming option not supported for tokenization.') + if kwargs.pop("stream", False): # not support stream + logger.warning("streaming option not supported for tokenization.") - return super().call(model=model, - task_group=None, - function=cls.FUNCTION, - api_key=api_key, - input=input, - is_service=False, - workspace=workspace, - **parameters) + return super().call( + model=model, + task_group=None, # type: ignore[arg-type] + function=cls.FUNCTION, + api_key=api_key, + input=input, + is_service=False, + workspace=workspace, + **parameters, + ) @classmethod def _build_llm_parameters(cls, model, prompt, history, messages, **kwargs): parameters = {} - input = {} + input = {} # pylint: disable=redefined-builtin if history is not None: logger.warning(DEPRECATED_MESSAGE) input[HISTORY] = history @@ -106,20 +127,21 @@ def _build_llm_parameters(cls, model, prompt, history, messages, **kwargs): elif messages is not None: msgs = copy.deepcopy(messages) if prompt is not None and prompt: - msgs.append({'role': Role.USER, 'content': prompt}) + msgs.append({"role": Role.USER, "content": prompt}) 
input = {MESSAGES: msgs} else: input[PROMPT] = prompt - if model.startswith('qwen'): - enable_search = kwargs.pop('enable_search', False) + if model.startswith("qwen"): + enable_search = kwargs.pop("enable_search", False) if enable_search: - parameters['enable_search'] = enable_search - elif model.startswith('bailian'): - customized_model_id = kwargs.pop('customized_model_id', None) + parameters["enable_search"] = enable_search + elif model.startswith("bailian"): + customized_model_id = kwargs.pop("customized_model_id", None) if customized_model_id is None: - raise InputRequired('customized_model_id is required for %s' % - model) + raise InputRequired( + f"customized_model_id is required for {model}", + ) input[CUSTOMIZED_MODEL_ID] = customized_model_id return input, {**parameters, **kwargs} diff --git a/dashscope/tokenizers/tokenizer.py b/dashscope/tokenizers/tokenizer.py index 62da0f9..05ba3b0 100644 --- a/dashscope/tokenizers/tokenizer.py +++ b/dashscope/tokenizers/tokenizer.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
import os @@ -8,7 +9,7 @@ from .tokenizer_base import Tokenizer -QWEN_SERIALS = ['qwen-7b-chat', 'qwen-turbo', 'qwen-plus', 'qwen-max'] +QWEN_SERIALS = ["qwen-7b-chat", "qwen-turbo", "qwen-plus", "qwen-max"] current_path = os.path.dirname(os.path.abspath(__file__)) root_path = os.path.dirname(current_path) @@ -27,13 +28,16 @@ def get_tokenizer(model: str) -> Tokenizer: """ if model in QWEN_SERIALS: return QwenTokenizer( - os.path.join(root_path, 'resources', 'qwen.tiktoken')) - elif model.startswith('qwen'): + os.path.join(root_path, "resources", "qwen.tiktoken"), + ) + elif model.startswith("qwen"): return QwenTokenizer( - os.path.join(root_path, 'resources', 'qwen.tiktoken')) + os.path.join(root_path, "resources", "qwen.tiktoken"), + ) else: raise UnsupportedModel( - f'Not support model: {model}, currently only support qwen models.') + f"Not support model: {model}, currently only support qwen models.", + ) def list_tokenizers() -> List[str]: diff --git a/dashscope/tokenizers/tokenizer_base.py b/dashscope/tokenizers/tokenizer_base.py index fdc563a..2e385bf 100644 --- a/dashscope/tokenizers/tokenizer_base.py +++ b/dashscope/tokenizers/tokenizer_base.py @@ -1,15 +1,16 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. from typing import List class Tokenizer: - """Base tokenizer interface for local tokenizers. - """ + """Base tokenizer interface for local tokenizers.""" + def __init__(self): pass - def encode(self, text: str, **kwargs) -> List[int]: + def encode(self, text: str, **kwargs) -> List[int]: # type: ignore[empty-body] # noqa: E501 """Encode input text string to token ids. Args: @@ -18,9 +19,8 @@ def encode(self, text: str, **kwargs) -> List[int]: Returns: List[int]: The token ids. """ - pass - def decode(self, token_ids: List[int], **kwargs) -> str: + def decode(self, token_ids: List[int], **kwargs) -> str: # type: ignore[empty-body] # pylint: disable=line-too-long # noqa: E501 """Decode token ids to string. 
Args: @@ -29,4 +29,3 @@ def decode(self, token_ids: List[int], **kwargs) -> str: Returns: str: The string of the token ids. """ - pass diff --git a/dashscope/utils/message_utils.py b/dashscope/utils/message_utils.py index 306c3f4..a11d66a 100644 --- a/dashscope/utils/message_utils.py +++ b/dashscope/utils/message_utils.py @@ -1,7 +1,15 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. import copy -def merge_single_response(parsed_response, accumulated_data, n=1): + +# pylint: disable=R1702,too-many-branches,too-many-return-statements +# pylint: disable=too-many-statements +def merge_single_response( # noqa: E501 + parsed_response, + accumulated_data, + n=1, +): """Merge a single response chunk with accumulated data. Args: @@ -16,54 +24,75 @@ def merge_single_response(parsed_response, accumulated_data, n=1): """ # Check if all choices have been sent (for n > 1 case) if n > 1 and accumulated_data: - all_sent = all(data.get('all_choices_sent', False) - for data in accumulated_data.values() - if isinstance(data, dict) and 'all_choices_sent' in data) + all_sent = all( + data.get("all_choices_sent", False) + for data in accumulated_data.values() + if isinstance(data, dict) and "all_choices_sent" in data + ) if all_sent: return False # Track usage for each choice index when n > 1 # Each streaming packet contains usage info for one specific choice - if (n > 1 and parsed_response.usage and - parsed_response.output and parsed_response.output.choices and - len(parsed_response.output.choices) > 0): - if 'usage_by_index' not in accumulated_data: - accumulated_data['usage_by_index'] = {} - - # Get the choice index from the first (and typically only) choice in this packet + if ( + n > 1 + and parsed_response.usage + and parsed_response.output + and parsed_response.output.choices + and len(parsed_response.output.choices) > 0 + ): + if "usage_by_index" not in accumulated_data: + accumulated_data["usage_by_index"] = {} + + # Get the choice index from the 
first (and typically only) choice in this packet # noqa: E501 # pylint: disable=line-too-long try: first_choice = parsed_response.output.choices[0] - choice_idx = first_choice.index if hasattr( - first_choice, 'index') and 'index' in first_choice else 0 + choice_idx = ( + first_choice.index + if hasattr( + first_choice, + "index", + ) + and "index" in first_choice + else 0 + ) # Store only output_tokens for this choice index - if 'output_tokens' in parsed_response.usage: - accumulated_data['usage_by_index'][choice_idx] = dict( - parsed_response.usage) + if "output_tokens" in parsed_response.usage: + accumulated_data["usage_by_index"][choice_idx] = dict( + parsed_response.usage, + ) except (KeyError, AttributeError, IndexError): pass # Handle output.text accumulation when choices is null - if (parsed_response.output and - hasattr(parsed_response.output, 'text') and - (not parsed_response.output.choices or parsed_response.output.choices is None)): + if ( + parsed_response.output + and hasattr(parsed_response.output, "text") + and ( + not parsed_response.output.choices + or parsed_response.output.choices is None + ) + ): choice_idx = 0 if choice_idx not in accumulated_data: accumulated_data[choice_idx] = { - 'content': '', - 'reasoning_content': '', - 'tool_calls': [], - 'logprobs': {'content': []}, - 'finished': False, - 'finish_reason': None, - 'all_choices_sent': False, - 'role': None + "content": "", + "reasoning_content": "", + "tool_calls": [], + "logprobs": {"content": []}, + "finished": False, + "finish_reason": None, + "all_choices_sent": False, + "role": None, } # Accumulate text if not empty if parsed_response.output.text: - accumulated_data[choice_idx]['content'] += parsed_response.output.text + accumulated_data[choice_idx][ + "content" + ] += parsed_response.output.text # Always set accumulated content back to response - parsed_response.output.text = accumulated_data[choice_idx]['content'] + parsed_response.output.text = 
accumulated_data[choice_idx]["content"] return True # Process each choice in the choices array @@ -77,197 +106,336 @@ def merge_single_response(parsed_response, accumulated_data, n=1): for choice_enum_idx, choice in enumerate(choices): # Use choice.index if available, otherwise use enumerate index try: - choice_idx = choice.index if hasattr(choice, 'index') and 'index' in choice else choice_enum_idx + choice_idx = ( + choice.index + if hasattr(choice, "index") and "index" in choice + else choice_enum_idx + ) except (KeyError, AttributeError): choice_idx = choice_enum_idx # Initialize accumulated data for this choice if not exists if choice_idx not in accumulated_data: accumulated_data[choice_idx] = { - 'content': '', - 'reasoning_content': '', - 'tool_calls': [], - 'logprobs': {'content': []}, - 'finished': False, - 'finish_reason': None, - 'all_choices_sent': False, - 'role': None + "content": "", + "reasoning_content": "", + "tool_calls": [], + "logprobs": {"content": []}, + "finished": False, + "finish_reason": None, + "all_choices_sent": False, + "role": None, } # Handle message field - create if null if not choice.message: # Create message object with accumulated data choice.message = { - 'role': accumulated_data[choice_idx]['role'] if accumulated_data[choice_idx]['role'] else 'assistant', - 'content': accumulated_data[choice_idx]['content'] + "role": accumulated_data[choice_idx]["role"] + if accumulated_data[choice_idx]["role"] + else "assistant", + "content": accumulated_data[choice_idx]["content"], } - if accumulated_data[choice_idx]['reasoning_content']: - choice.message['reasoning_content'] = accumulated_data[choice_idx]['reasoning_content'] - if accumulated_data[choice_idx]['tool_calls']: - choice.message['tool_calls'] = accumulated_data[choice_idx]['tool_calls'] + if accumulated_data[choice_idx]["reasoning_content"]: + choice.message["reasoning_content"] = accumulated_data[ + choice_idx + ]["reasoning_content"] + if 
accumulated_data[choice_idx]["tool_calls"]: + choice.message["tool_calls"] = accumulated_data[ + choice_idx + ]["tool_calls"] else: # Save role if present - if hasattr(choice.message, 'role') and choice.message.role: - accumulated_data[choice_idx]['role'] = choice.message.role + if hasattr(choice.message, "role") and choice.message.role: + accumulated_data[choice_idx]["role"] = choice.message.role # Handle content accumulation - if 'content' in choice.message: + if "content" in choice.message: current_content = choice.message.content if current_content: # Check if content is multimodal format if isinstance(current_content, list): # Handle multimodal content (array format) - # Initialize accumulated content as array if not already - if not isinstance(accumulated_data[choice_idx]['content'], list): - accumulated_data[choice_idx]['content'] = [] - - # Ensure accumulated content list has enough elements - while len(accumulated_data[choice_idx]['content']) < len(current_content): - accumulated_data[choice_idx]['content'].append({'text': ''}) + # Initialize accumulated content as array if not already # noqa: E501 + if not isinstance( + accumulated_data[choice_idx]["content"], + list, + ): + accumulated_data[choice_idx]["content"] = [] + + # Ensure accumulated content list has enough elements # noqa: E501 + while len( + accumulated_data[choice_idx]["content"], + ) < len(current_content): + accumulated_data[choice_idx]["content"].append( + {"text": ""}, + ) # Merge each content element - for content_idx, content_item in enumerate(current_content): - if isinstance(content_item, dict) and 'text' in content_item: - if content_item['text']: + for content_idx, content_item in enumerate( + current_content, + ): + if ( + isinstance(content_item, dict) + and "text" in content_item + ): + if content_item["text"]: # Accumulate text content - accumulated_data[choice_idx]['content'][content_idx]['text'] += content_item['text'] - # Update the current response with accumulated content - 
for content_idx in range(len(accumulated_data[choice_idx]['content'])): + accumulated_data[choice_idx][ + "content" + ][content_idx]["text"] += content_item[ + "text" + ] + # Update the current response with accumulated content # noqa: E501 + for content_idx in range( + len(accumulated_data[choice_idx]["content"]), + ): if content_idx < len(choice.message.content): - choice.message.content[content_idx]['text'] = accumulated_data[choice_idx]['content'][content_idx]['text'] + choice.message.content[content_idx][ + "text" + ] = accumulated_data[choice_idx][ + "content" + ][ + content_idx + ][ + "text" + ] else: # Handle regular content (string format) # Initialize accumulated content as string - if isinstance(accumulated_data[choice_idx]['content'], list): - accumulated_data[choice_idx]['content'] = '' + if isinstance( + accumulated_data[choice_idx]["content"], + list, + ): + accumulated_data[choice_idx]["content"] = "" # Accumulate content if not empty - accumulated_data[choice_idx]['content'] += current_content + accumulated_data[choice_idx][ + "content" + ] += current_content # Always set accumulated content back to response - if not isinstance(accumulated_data[choice_idx]['content'], list): - choice.message.content = accumulated_data[choice_idx]['content'] + if not isinstance( + accumulated_data[choice_idx]["content"], + list, + ): + choice.message.content = accumulated_data[choice_idx][ + "content" + ] else: # For multimodal content, ensure message.content # exists if not isinstance(choice.message.content, list): - choice.message.content = accumulated_data[choice_idx]['content'] + choice.message.content = accumulated_data[ + choice_idx + ]["content"] # Handle reasoning_content accumulation - if 'reasoning_content' in choice.message: - current_reasoning_content = choice.message.reasoning_content + if "reasoning_content" in choice.message: + current_reasoning_content = ( + choice.message.reasoning_content + ) if current_reasoning_content: - 
accumulated_data[choice_idx]['reasoning_content'] += current_reasoning_content + accumulated_data[choice_idx][ + "reasoning_content" + ] += current_reasoning_content # Always set the accumulated reasoning_content back if we # have any, even if current response doesn't have it - if accumulated_data[choice_idx]['reasoning_content']: - choice.message.reasoning_content = accumulated_data[choice_idx]['reasoning_content'] + if accumulated_data[choice_idx]["reasoning_content"]: + choice.message.reasoning_content = accumulated_data[ + choice_idx + ]["reasoning_content"] # Handle tool_calls accumulation - if 'tool_calls' in choice.message and choice.message.tool_calls: + if ( + "tool_calls" in choice.message + and choice.message.tool_calls + ): current_tool_calls = choice.message.tool_calls # For each current tool call, accumulate its arguments for current_call in current_tool_calls: - if isinstance(current_call, dict) and 'index' in current_call: - idx = current_call['index'] + if ( + isinstance(current_call, dict) + and "index" in current_call + ): + idx = current_call["index"] # Find existing accumulated call with same index existing_call = None - for acc_call in accumulated_data[choice_idx]['tool_calls']: - if (isinstance(acc_call, dict) and - acc_call.get('index') == idx): + for acc_call in accumulated_data[choice_idx][ + "tool_calls" + ]: + if ( + isinstance(acc_call, dict) + and acc_call.get("index") == idx + ): existing_call = acc_call break if existing_call: # Accumulate function fields from current call - if ('function' in current_call and - current_call['function']): - if 'function' not in existing_call: - existing_call['function'] = {} + if ( + "function" in current_call + and current_call["function"] + ): + if "function" not in existing_call: + existing_call["function"] = {} # Accumulate function.name - if 'name' in current_call['function']: - if 'name' not in existing_call['function']: - existing_call['function']['name'] = '' - 
existing_call['function']['name'] += current_call['function']['name'] + if "name" in current_call["function"]: + if ( + "name" + not in existing_call["function"] + ): + existing_call["function"][ + "name" + ] = "" + existing_call["function"][ + "name" + ] += current_call["function"]["name"] # Accumulate function.arguments - if 'arguments' in current_call['function']: - if 'arguments' not in existing_call['function']: - existing_call['function']['arguments'] = '' - existing_call['function']['arguments'] += current_call['function']['arguments'] + if "arguments" in current_call["function"]: + if ( + "arguments" + not in existing_call["function"] + ): + existing_call["function"][ + "arguments" + ] = "" + existing_call["function"][ + "arguments" + ] += current_call["function"][ + "arguments" + ] # Update other fields with latest values - existing_call.update({k: v for k, v in current_call.items() - if k != 'function' and v}) - if 'function' in current_call and current_call['function']: - existing_call['function'].update({k: v for k, v in current_call['function'].items() - if k not in ['arguments', 'name'] and v}) + existing_call.update( + { + k: v + for k, v in current_call.items() + if k != "function" and v + }, + ) + if ( + "function" in current_call + and current_call["function"] + ): + existing_call["function"].update( + { + k: v + for k, v in current_call[ + "function" + ].items() + if k not in ["arguments", "name"] + and v + }, + ) else: # Add new tool call - accumulated_data[choice_idx]['tool_calls'].append(dict(current_call)) + accumulated_data[choice_idx][ + "tool_calls" + ].append(dict(current_call)) # Update choice with accumulated tool_calls - choice.message.tool_calls = accumulated_data[choice_idx]['tool_calls'] - elif accumulated_data[choice_idx]['tool_calls']: + choice.message.tool_calls = accumulated_data[choice_idx][ + "tool_calls" + ] + elif accumulated_data[choice_idx]["tool_calls"]: # If current response has no tool_calls but we have # accumulated 
tool_calls, restore them - choice.message.tool_calls = accumulated_data[choice_idx]['tool_calls'] + choice.message.tool_calls = accumulated_data[choice_idx][ + "tool_calls" + ] # Restore role if we have it - if accumulated_data[choice_idx]['role'] and (not hasattr(choice.message, 'role') or not choice.message.role): - choice.message.role = accumulated_data[choice_idx]['role'] + if accumulated_data[choice_idx]["role"] and ( + not hasattr(choice.message, "role") + or not choice.message.role + ): + choice.message.role = accumulated_data[choice_idx]["role"] # Handle logprobs accumulation (only if logprobs exists) try: - if ('logprobs' in choice and choice.logprobs and - isinstance(choice.logprobs, dict) and 'content' in choice.logprobs): - current_logprobs_content = choice.logprobs['content'] - if current_logprobs_content and isinstance(current_logprobs_content, list): + if ( + "logprobs" in choice + and choice.logprobs + and isinstance(choice.logprobs, dict) + and "content" in choice.logprobs + ): + current_logprobs_content = choice.logprobs["content"] + if current_logprobs_content and isinstance( + current_logprobs_content, + list, + ): # Initialize logprobs content if not exists - if 'logprobs' not in accumulated_data[choice_idx]: - accumulated_data[choice_idx]['logprobs'] = {'content': []} - elif 'content' not in accumulated_data[choice_idx]['logprobs']: - accumulated_data[choice_idx]['logprobs']['content'] = [] + if "logprobs" not in accumulated_data[choice_idx]: + accumulated_data[choice_idx]["logprobs"] = { + "content": [], + } + elif ( + "content" + not in accumulated_data[choice_idx]["logprobs"] + ): + accumulated_data[choice_idx]["logprobs"][ + "content" + ] = [] # Extend the accumulated logprobs content array - accumulated_data[choice_idx]['logprobs']['content'].extend(current_logprobs_content) + accumulated_data[choice_idx]["logprobs"][ + "content" + ].extend(current_logprobs_content) except (KeyError, AttributeError, TypeError): - # logprobs field might 
not exist or be in unexpected format, safely skip + # logprobs field might not exist or be in unexpected format, safely skip # noqa: E501 # pylint: disable=line-too-long pass # Always set accumulated logprobs if we have any - if (accumulated_data[choice_idx]['logprobs']['content'] and - hasattr(choice, 'logprobs') and choice.logprobs): - choice.logprobs['content'] = accumulated_data[choice_idx][ - 'logprobs']['content'] + if ( + accumulated_data[choice_idx]["logprobs"]["content"] + and hasattr(choice, "logprobs") + and choice.logprobs + ): + choice.logprobs["content"] = accumulated_data[choice_idx][ + "logprobs" + ]["content"] # Handle finish_reason for n > 1 case - if (n > 1 and hasattr(choice, 'finish_reason') and - choice.finish_reason and - choice.finish_reason != 'null'): - accumulated_data[choice_idx]['finish_reason'] = \ - choice.finish_reason - accumulated_data[choice_idx]['finished'] = True + if ( + n > 1 + and hasattr(choice, "finish_reason") + and choice.finish_reason + and choice.finish_reason != "null" + ): + accumulated_data[choice_idx][ + "finish_reason" + ] = choice.finish_reason + accumulated_data[choice_idx]["finished"] = True # Handle n > 1 case: different strategies for different finish_reason if n > 1: # Count finished choices - finished_count = sum(1 for data in accumulated_data.values() - if isinstance(data, dict) and - data.get('finished', False)) + finished_count = sum( + 1 + for data in accumulated_data.values() + if isinstance(data, dict) and data.get("finished", False) + ) # Find all finished choices in current packet finished_choices_in_packet = [] for choice in choices: - if (hasattr(choice, 'finish_reason') and - choice.finish_reason and - choice.finish_reason != 'null'): - choice_idx = (choice.index if hasattr(choice, 'index') and - 'index' in choice else 0) + if ( + hasattr(choice, "finish_reason") + and choice.finish_reason + and choice.finish_reason != "null" + ): + choice_idx = ( + choice.index + if hasattr(choice, "index") and 
"index" in choice + else 0 + ) finish_reason = choice.finish_reason finished_choices_in_packet.append( - (choice_idx, finish_reason, choice)) + (choice_idx, finish_reason, choice), + ) # No finish_reason in current packet: return as is if not finished_choices_in_packet: @@ -277,55 +445,69 @@ def merge_single_response(parsed_response, accumulated_data, n=1): first_finish_reason = finished_choices_in_packet[0][1] # For stop: wait all choices, then merge into one result - if first_finish_reason == 'stop': + if first_finish_reason == "stop": if finished_count < n: # Hide finish_reason until all finished for choice in choices: - if (hasattr(choice, 'finish_reason') and - choice.finish_reason and - choice.finish_reason != 'null'): - choice.finish_reason = 'null' + if ( + hasattr(choice, "finish_reason") + and choice.finish_reason + and choice.finish_reason != "null" + ): + choice.finish_reason = "null" else: # All finished: merge all choices into one result for data in accumulated_data.values(): - if isinstance(data, dict) and 'all_choices_sent' in data: - data['all_choices_sent'] = True + if ( + isinstance(data, dict) + and "all_choices_sent" in data + ): + data["all_choices_sent"] = True # Return final result with all choices all_choices = [] # Sort by choice_idx to ensure correct order sorted_items = sorted( - [(idx, data) for idx, data in accumulated_data.items() - if isinstance(data, dict) and 'finished' in data], - key=lambda x: x[0] + [ + (idx, data) + for idx, data in accumulated_data.items() + if isinstance(data, dict) and "finished" in data + ], + key=lambda x: x[0], ) for choice_idx, data in sorted_items: # Create a new choice object final_choice_dict = { - 'index': choice_idx, - 'finish_reason': data['finish_reason'] + "index": choice_idx, + "finish_reason": data["finish_reason"], } # Create message message_dict = { - 'role': data['role'] if data['role'] else 'assistant' + "role": data["role"] + if data["role"] + else "assistant", } - if data['content']: - 
message_dict['content'] = ( - data['content'] if isinstance(data['content'], str) - else data['content']) - if data['reasoning_content']: - message_dict['reasoning_content'] = data['reasoning_content'] - if data['tool_calls']: - message_dict['tool_calls'] = data['tool_calls'] - - final_choice_dict['message'] = message_dict + if data["content"]: + message_dict["content"] = ( + data["content"] + if isinstance(data["content"], str) + else data["content"] + ) + if data["reasoning_content"]: + message_dict["reasoning_content"] = data[ + "reasoning_content" + ] + if data["tool_calls"]: + message_dict["tool_calls"] = data["tool_calls"] + + final_choice_dict["message"] = message_dict # Add logprobs if present - if data['logprobs']['content']: - final_choice_dict['logprobs'] = { - 'content': data['logprobs']['content'] + if data["logprobs"]["content"]: + final_choice_dict["logprobs"] = { + "content": data["logprobs"]["content"], } all_choices.append(final_choice_dict) @@ -334,10 +516,12 @@ def merge_single_response(parsed_response, accumulated_data, n=1): parsed_response.output.choices = all_choices # Aggregate usage from all choice indices - if 'usage_by_index' in accumulated_data and accumulated_data[ - 'usage_by_index']: + if ( + "usage_by_index" in accumulated_data + and accumulated_data["usage_by_index"] + ): aggregated_usage = {} - usage_by_idx = accumulated_data['usage_by_index'] + usage_by_idx = accumulated_data["usage_by_index"] # Sum output_tokens and recalculate total_tokens total_output_tokens = 0 @@ -345,42 +529,56 @@ def merge_single_response(parsed_response, accumulated_data, n=1): prompt_tokens_details = None for idx, usage in usage_by_idx.items(): - if 'output_tokens' in usage: - total_output_tokens += usage['output_tokens'] + if "output_tokens" in usage: + total_output_tokens += usage["output_tokens"] # input_tokens should be the same for all indices - if input_tokens is None and 'input_tokens' in usage: - input_tokens = usage['input_tokens'] + if ( + 
input_tokens is None + and "input_tokens" in usage + ): + input_tokens = usage["input_tokens"] # Keep prompt_tokens_details from any index # (should be same) - if (prompt_tokens_details is None and - 'prompt_tokens_details' in usage): + if ( + prompt_tokens_details is None + and "prompt_tokens_details" in usage + ): prompt_tokens_details = usage[ - 'prompt_tokens_details'] + "prompt_tokens_details" + ] # Build aggregated usage if input_tokens is not None: - aggregated_usage['input_tokens'] = input_tokens - aggregated_usage['output_tokens'] = total_output_tokens + aggregated_usage["input_tokens"] = input_tokens + aggregated_usage["output_tokens"] = total_output_tokens if input_tokens is not None: - aggregated_usage['total_tokens'] = ( - input_tokens + total_output_tokens) + aggregated_usage["total_tokens"] = ( + input_tokens + total_output_tokens + ) if prompt_tokens_details is not None: - aggregated_usage['prompt_tokens_details'] = ( - prompt_tokens_details) + aggregated_usage[ + "prompt_tokens_details" + ] = prompt_tokens_details # Update response usage with aggregated values parsed_response.usage = aggregated_usage else: - # For non-stop (e.g., tool_calls): output each choice separately + # For non-stop (e.g., tool_calls): output each choice separately # noqa: E501 responses_to_yield = [] - for choice_idx, finish_reason, choice in finished_choices_in_packet: + for ( + choice_idx, + finish_reason, + choice, + ) in finished_choices_in_packet: current_data = accumulated_data.get(choice_idx) - if (current_data is None or - current_data.get('all_choices_sent', False)): + if current_data is None or current_data.get( + "all_choices_sent", + False, + ): continue - current_data['all_choices_sent'] = True + current_data["all_choices_sent"] = True # Create a new response for this choice if responses_to_yield: @@ -397,18 +595,23 @@ def merge_single_response(parsed_response, accumulated_data, n=1): new_response.output.choices = [choice_copy] # Update usage with this choice's 
output tokens - if (new_response.usage and - 'usage_by_index' in accumulated_data and - choice_idx in accumulated_data['usage_by_index']): - current_usage = accumulated_data['usage_by_index'][ - choice_idx] - if 'output_tokens' in current_usage: - new_response.usage['output_tokens'] = ( - current_usage['output_tokens']) - if 'input_tokens' in current_usage: - new_response.usage['total_tokens'] = ( - current_usage['input_tokens'] + - current_usage['output_tokens']) + if ( + new_response.usage + and "usage_by_index" in accumulated_data + and choice_idx in accumulated_data["usage_by_index"] + ): + current_usage = accumulated_data["usage_by_index"][ + choice_idx + ] + if "output_tokens" in current_usage: + new_response.usage[ + "output_tokens" + ] = current_usage["output_tokens"] + if "input_tokens" in current_usage: + new_response.usage["total_tokens"] = ( + current_usage["input_tokens"] + + current_usage["output_tokens"] + ) responses_to_yield.append(new_response) @@ -421,7 +624,12 @@ def merge_single_response(parsed_response, accumulated_data, n=1): return True -def merge_multimodal_single_response(parsed_response, accumulated_data, n=1): +# pylint: disable=R1702,too-many-branches,too-many-statements +def merge_multimodal_single_response( # noqa: E501 + parsed_response, + accumulated_data, + n=1, +): """Merge a single response chunk with accumulated data. 
Args: @@ -434,53 +642,74 @@ def merge_multimodal_single_response(parsed_response, accumulated_data, n=1): """ # Check if all choices have been sent (for n > 1 case) if n > 1 and accumulated_data: - all_sent = any(data.get('all_choices_sent', False) - for data in accumulated_data.values()) + all_sent = any( + data.get("all_choices_sent", False) + for data in accumulated_data.values() + ) if all_sent: return False # Track usage for each choice index when n > 1 # Each streaming packet contains usage info for one specific choice - if (n > 1 and parsed_response.usage and - parsed_response.output and parsed_response.output.choices and - len(parsed_response.output.choices) > 0): - if 'usage_by_index' not in accumulated_data: - accumulated_data['usage_by_index'] = {} - - # Get the choice index from the first (and typically only) choice in this packet + if ( + n > 1 + and parsed_response.usage + and parsed_response.output + and parsed_response.output.choices + and len(parsed_response.output.choices) > 0 + ): + if "usage_by_index" not in accumulated_data: + accumulated_data["usage_by_index"] = {} + + # Get the choice index from the first (and typically only) choice in this packet # noqa: E501 # pylint: disable=line-too-long try: first_choice = parsed_response.output.choices[0] - choice_idx = first_choice.index if hasattr( - first_choice, 'index') and 'index' in first_choice else 0 + choice_idx = ( + first_choice.index + if hasattr( + first_choice, + "index", + ) + and "index" in first_choice + else 0 + ) # Store only output_tokens for this choice index - if 'output_tokens' in parsed_response.usage: - accumulated_data['usage_by_index'][choice_idx] = dict( - parsed_response.usage) + if "output_tokens" in parsed_response.usage: + accumulated_data["usage_by_index"][choice_idx] = dict( + parsed_response.usage, + ) except (KeyError, AttributeError, IndexError): pass # Handle output.text accumulation when choices is null - if (parsed_response.output and - 
hasattr(parsed_response.output, 'text') and - (not parsed_response.output.choices or parsed_response.output.choices is None)): + if ( + parsed_response.output + and hasattr(parsed_response.output, "text") + and ( + not parsed_response.output.choices + or parsed_response.output.choices is None + ) + ): choice_idx = 0 if choice_idx not in accumulated_data: accumulated_data[choice_idx] = { - 'content': '', - 'reasoning_content': '', - 'tool_calls': [], - 'logprobs': {'content': []}, - 'finished': False, - 'finish_reason': None, - 'all_choices_sent': False, - 'role': None + "content": "", + "reasoning_content": "", + "tool_calls": [], + "logprobs": {"content": []}, + "finished": False, + "finish_reason": None, + "all_choices_sent": False, + "role": None, } # Accumulate text if not empty if parsed_response.output.text: - accumulated_data[choice_idx]['content'] += parsed_response.output.text + accumulated_data[choice_idx][ + "content" + ] += parsed_response.output.text # Always set accumulated content back to response - parsed_response.output.text = accumulated_data[choice_idx]['content'] + parsed_response.output.text = accumulated_data[choice_idx]["content"] return True # Process each choice in the choices array @@ -494,194 +723,325 @@ def merge_multimodal_single_response(parsed_response, accumulated_data, n=1): for choice_enum_idx, choice in enumerate(choices): # Use choice.index if available, otherwise use enumerate index try: - choice_idx = choice.index if hasattr(choice, 'index') and 'index' in choice else choice_enum_idx + choice_idx = ( + choice.index + if hasattr(choice, "index") and "index" in choice + else choice_enum_idx + ) except (KeyError, AttributeError): choice_idx = choice_enum_idx # Initialize accumulated data for this choice if not exists if choice_idx not in accumulated_data: accumulated_data[choice_idx] = { - 'content': '', - 'reasoning_content': '', - 'tool_calls': [], - 'logprobs': {'content': []}, - 'finished': False, - 'finish_reason': None, - 
'all_choices_sent': False, - 'role': None + "content": "", + "reasoning_content": "", + "tool_calls": [], + "logprobs": {"content": []}, + "finished": False, + "finish_reason": None, + "all_choices_sent": False, + "role": None, } # Handle message field - create if null if not choice.message: # Create message object with accumulated data choice.message = { - 'role': accumulated_data[choice_idx]['role'] if accumulated_data[choice_idx]['role'] else 'assistant', - 'content': accumulated_data[choice_idx]['content'] + "role": accumulated_data[choice_idx]["role"] + if accumulated_data[choice_idx]["role"] + else "assistant", + "content": accumulated_data[choice_idx]["content"], } - if accumulated_data[choice_idx]['reasoning_content']: - choice.message['reasoning_content'] = accumulated_data[choice_idx]['reasoning_content'] - if accumulated_data[choice_idx]['tool_calls']: - choice.message['tool_calls'] = accumulated_data[choice_idx]['tool_calls'] + if accumulated_data[choice_idx]["reasoning_content"]: + choice.message["reasoning_content"] = accumulated_data[ + choice_idx + ]["reasoning_content"] + if accumulated_data[choice_idx]["tool_calls"]: + choice.message["tool_calls"] = accumulated_data[ + choice_idx + ]["tool_calls"] else: # Save role if present - if hasattr(choice.message, 'role') and choice.message.role: - accumulated_data[choice_idx]['role'] = choice.message.role + if hasattr(choice.message, "role") and choice.message.role: + accumulated_data[choice_idx]["role"] = choice.message.role # Handle content accumulation - if 'content' in choice.message: + if "content" in choice.message: current_content = choice.message.content # Check if content is multimodal format if isinstance(current_content, list): # Handle multimodal content (array format) - # Initialize accumulated content as array if not already - if not isinstance(accumulated_data[choice_idx]['content'], list): - accumulated_data[choice_idx]['content'] = [] + # Initialize accumulated content as array if not 
already # noqa: E501 + if not isinstance( + accumulated_data[choice_idx]["content"], + list, + ): + accumulated_data[choice_idx]["content"] = [] # Only process if current_content is not empty if current_content: - # Ensure accumulated content list has enough elements - while len(accumulated_data[choice_idx]['content']) < len(current_content): - accumulated_data[choice_idx]['content'].append({'text': ''}) + # Ensure accumulated content list has enough elements # noqa: E501 + while len( + accumulated_data[choice_idx]["content"], + ) < len(current_content): + accumulated_data[choice_idx]["content"].append( + {"text": ""}, + ) # Merge each content element - for content_idx, content_item in enumerate(current_content): - if isinstance(content_item, dict) and 'text' in content_item: - if content_item['text']: + for content_idx, content_item in enumerate( + current_content, + ): + if ( + isinstance(content_item, dict) + and "text" in content_item + ): + if content_item["text"]: # Accumulate text content - accumulated_data[choice_idx]['content'][content_idx]['text'] += content_item['text'] + accumulated_data[choice_idx][ + "content" + ][content_idx]["text"] += content_item[ + "text" + ] # Always set accumulated content back to response - choice.message.content = accumulated_data[choice_idx]['content'] + choice.message.content = accumulated_data[choice_idx][ + "content" + ] elif current_content: # Handle regular content (string format) # Initialize accumulated content as string - if isinstance(accumulated_data[choice_idx]['content'], list): - accumulated_data[choice_idx]['content'] = '' + if isinstance( + accumulated_data[choice_idx]["content"], + list, + ): + accumulated_data[choice_idx]["content"] = "" # Accumulate content if not empty - accumulated_data[choice_idx]['content'] += current_content + accumulated_data[choice_idx][ + "content" + ] += current_content # Set accumulated content back to response - choice.message.content = accumulated_data[choice_idx]['content'] - 
elif not current_content and accumulated_data[choice_idx]['content']: - # Current content is empty but we have accumulated content, restore it - choice.message.content = accumulated_data[choice_idx]['content'] + choice.message.content = accumulated_data[choice_idx][ + "content" + ] + elif ( + not current_content + and accumulated_data[choice_idx]["content"] + ): + # Current content is empty but we have accumulated content, restore it # noqa: E501 # pylint: disable=line-too-long + choice.message.content = accumulated_data[choice_idx][ + "content" + ] # Handle reasoning_content accumulation - if 'reasoning_content' in choice.message: - current_reasoning_content = choice.message.reasoning_content + if "reasoning_content" in choice.message: + current_reasoning_content = ( + choice.message.reasoning_content + ) if current_reasoning_content: - accumulated_data[choice_idx]['reasoning_content'] += current_reasoning_content + accumulated_data[choice_idx][ + "reasoning_content" + ] += current_reasoning_content # Always set the accumulated reasoning_content back if we # have any, even if current response doesn't have it - if accumulated_data[choice_idx]['reasoning_content']: - choice.message.reasoning_content = accumulated_data[choice_idx]['reasoning_content'] + if accumulated_data[choice_idx]["reasoning_content"]: + choice.message.reasoning_content = accumulated_data[ + choice_idx + ]["reasoning_content"] # Handle tool_calls accumulation - if 'tool_calls' in choice.message and choice.message.tool_calls: + if ( + "tool_calls" in choice.message + and choice.message.tool_calls + ): current_tool_calls = choice.message.tool_calls # For each current tool call, accumulate its arguments for current_call in current_tool_calls: - if isinstance(current_call, dict) and 'index' in current_call: - idx = current_call['index'] + if ( + isinstance(current_call, dict) + and "index" in current_call + ): + idx = current_call["index"] # Find existing accumulated call with same index 
existing_call = None - for acc_call in accumulated_data[choice_idx]['tool_calls']: - if (isinstance(acc_call, dict) and - acc_call.get('index') == idx): + for acc_call in accumulated_data[choice_idx][ + "tool_calls" + ]: + if ( + isinstance(acc_call, dict) + and acc_call.get("index") == idx + ): existing_call = acc_call break if existing_call: # Accumulate function fields from current call - if ('function' in current_call and - current_call['function']): - if 'function' not in existing_call: - existing_call['function'] = {} + if ( + "function" in current_call + and current_call["function"] + ): + if "function" not in existing_call: + existing_call["function"] = {} # Accumulate function.name - if 'name' in current_call['function']: - if 'name' not in existing_call['function']: - existing_call['function']['name'] = '' - existing_call['function']['name'] += current_call['function']['name'] + if "name" in current_call["function"]: + if ( + "name" + not in existing_call["function"] + ): + existing_call["function"][ + "name" + ] = "" + existing_call["function"][ + "name" + ] += current_call["function"]["name"] # Accumulate function.arguments - if 'arguments' in current_call['function']: - if 'arguments' not in existing_call['function']: - existing_call['function']['arguments'] = '' - existing_call['function']['arguments'] += current_call['function']['arguments'] + if "arguments" in current_call["function"]: + if ( + "arguments" + not in existing_call["function"] + ): + existing_call["function"][ + "arguments" + ] = "" + existing_call["function"][ + "arguments" + ] += current_call["function"][ + "arguments" + ] # Update other fields with latest values - existing_call.update({k: v for k, v in current_call.items() - if k != 'function' and v}) - if 'function' in current_call and current_call['function']: - existing_call['function'].update({k: v for k, v in current_call['function'].items() - if k not in ['arguments', 'name'] and v}) + existing_call.update( + { + k: v + for k, 
v in current_call.items() + if k != "function" and v + }, + ) + if ( + "function" in current_call + and current_call["function"] + ): + existing_call["function"].update( + { + k: v + for k, v in current_call[ + "function" + ].items() + if k not in ["arguments", "name"] + and v + }, + ) else: # Add new tool call - accumulated_data[choice_idx]['tool_calls'].append(dict(current_call)) + accumulated_data[choice_idx][ + "tool_calls" + ].append(dict(current_call)) # Update choice with accumulated tool_calls - choice.message.tool_calls = accumulated_data[choice_idx]['tool_calls'] - elif accumulated_data[choice_idx]['tool_calls']: - # If current response has no tool_calls but we have accumulated tool_calls, restore them - choice.message.tool_calls = accumulated_data[choice_idx]['tool_calls'] + choice.message.tool_calls = accumulated_data[choice_idx][ + "tool_calls" + ] + elif accumulated_data[choice_idx]["tool_calls"]: + # If current response has no tool_calls but we have accumulated tool_calls, restore them # noqa: E501 # pylint: disable=line-too-long + choice.message.tool_calls = accumulated_data[choice_idx][ + "tool_calls" + ] # Restore role if we have it - if accumulated_data[choice_idx]['role'] and (not hasattr(choice.message, 'role') or not choice.message.role): - choice.message.role = accumulated_data[choice_idx]['role'] + if accumulated_data[choice_idx]["role"] and ( + not hasattr(choice.message, "role") + or not choice.message.role + ): + choice.message.role = accumulated_data[choice_idx]["role"] # Handle logprobs accumulation (only if logprobs exists) try: - if ('logprobs' in choice and choice.logprobs and - isinstance(choice.logprobs, dict) and 'content' in choice.logprobs): - current_logprobs_content = choice.logprobs['content'] - if current_logprobs_content and isinstance(current_logprobs_content, list): + if ( + "logprobs" in choice + and choice.logprobs + and isinstance(choice.logprobs, dict) + and "content" in choice.logprobs + ): + current_logprobs_content 
= choice.logprobs["content"] + if current_logprobs_content and isinstance( + current_logprobs_content, + list, + ): # Initialize logprobs content if not exists - if 'logprobs' not in accumulated_data[choice_idx]: - accumulated_data[choice_idx]['logprobs'] = {'content': []} - elif 'content' not in accumulated_data[choice_idx]['logprobs']: - accumulated_data[choice_idx]['logprobs']['content'] = [] + if "logprobs" not in accumulated_data[choice_idx]: + accumulated_data[choice_idx]["logprobs"] = { + "content": [], + } + elif ( + "content" + not in accumulated_data[choice_idx]["logprobs"] + ): + accumulated_data[choice_idx]["logprobs"][ + "content" + ] = [] # Extend the accumulated logprobs content array - accumulated_data[choice_idx]['logprobs']['content'].extend(current_logprobs_content) + accumulated_data[choice_idx]["logprobs"][ + "content" + ].extend(current_logprobs_content) except (KeyError, AttributeError, TypeError): - # logprobs field might not exist or be in unexpected format, safely skip + # logprobs field might not exist or be in unexpected format, safely skip # noqa: E501 # pylint: disable=line-too-long pass # Always set accumulated logprobs if we have any - if (accumulated_data[choice_idx]['logprobs']['content'] and - hasattr(choice, 'logprobs') and choice.logprobs): - choice.logprobs['content'] = accumulated_data[choice_idx][ - 'logprobs']['content'] + if ( + accumulated_data[choice_idx]["logprobs"]["content"] + and hasattr(choice, "logprobs") + and choice.logprobs + ): + choice.logprobs["content"] = accumulated_data[choice_idx][ + "logprobs" + ]["content"] # Handle finish_reason for n > 1 case - if (n > 1 and hasattr(choice, 'finish_reason') and - choice.finish_reason and - choice.finish_reason != 'null'): - accumulated_data[choice_idx]['finish_reason'] = \ - choice.finish_reason - accumulated_data[choice_idx]['finished'] = True + if ( + n > 1 + and hasattr(choice, "finish_reason") + and choice.finish_reason + and choice.finish_reason != "null" + ): + 
accumulated_data[choice_idx][ + "finish_reason" + ] = choice.finish_reason + accumulated_data[choice_idx]["finished"] = True # Handle n > 1 case: different strategies for different # finish_reason if n > 1: # Count finished choices - finished_count = sum(1 for data in accumulated_data.values() - if isinstance(data, dict) and - data.get('finished', False)) + finished_count = sum( + 1 + for data in accumulated_data.values() + if isinstance(data, dict) and data.get("finished", False) + ) # Find all finished choices in current packet finished_choices_in_packet = [] for choice in choices: - if (hasattr(choice, 'finish_reason') and - choice.finish_reason and - choice.finish_reason != 'null'): - choice_idx = (choice.index if hasattr(choice, 'index') and - 'index' in choice else 0) + if ( + hasattr(choice, "finish_reason") + and choice.finish_reason + and choice.finish_reason != "null" + ): + choice_idx = ( + choice.index + if hasattr(choice, "index") and "index" in choice + else 0 + ) finish_reason = choice.finish_reason finished_choices_in_packet.append( - (choice_idx, finish_reason, choice)) + (choice_idx, finish_reason, choice), + ) # No finish_reason in current packet: return as is if not finished_choices_in_packet: @@ -691,57 +1051,72 @@ def merge_multimodal_single_response(parsed_response, accumulated_data, n=1): first_finish_reason = finished_choices_in_packet[0][1] # For stop: wait all choices, then merge into one result - if first_finish_reason == 'stop': + if first_finish_reason == "stop": if finished_count < n: # Hide finish_reason until all finished for choice in choices: - if (hasattr(choice, 'finish_reason') and - choice.finish_reason and - choice.finish_reason != 'null'): - choice.finish_reason = 'null' + if ( + hasattr(choice, "finish_reason") + and choice.finish_reason + and choice.finish_reason != "null" + ): + choice.finish_reason = "null" else: # All finished: merge all choices into one result for data in accumulated_data.values(): - if 
isinstance(data, dict) and 'all_choices_sent' in data: - data['all_choices_sent'] = True + if ( + isinstance(data, dict) + and "all_choices_sent" in data + ): + data["all_choices_sent"] = True # Return final result with all choices all_choices = [] # Sort by choice_idx to ensure correct order sorted_items = sorted( - [(idx, data) for idx, data in accumulated_data.items() - if isinstance(data, dict) and 'finished' in data], - key=lambda x: x[0] + [ + (idx, data) + for idx, data in accumulated_data.items() + if isinstance(data, dict) and "finished" in data + ], + key=lambda x: x[0], ) for choice_idx, data in sorted_items: # Create a new choice object final_choice_dict = { - 'index': choice_idx, - 'finish_reason': data['finish_reason'] + "index": choice_idx, + "finish_reason": data["finish_reason"], } # Create message message_dict = { - 'role': data['role'] if data['role'] else 'assistant' + "role": data["role"] + if data["role"] + else "assistant", } - if data['content']: - message_dict['content'] = ( - data['content'] if isinstance(data['content'], - str) - else data['content']) - if data['reasoning_content']: - message_dict['reasoning_content'] = ( - data['reasoning_content']) - if data['tool_calls']: - message_dict['tool_calls'] = data['tool_calls'] - - final_choice_dict['message'] = message_dict + if data["content"]: + message_dict["content"] = ( + data["content"] + if isinstance( + data["content"], + str, + ) + else data["content"] + ) + if data["reasoning_content"]: + message_dict["reasoning_content"] = data[ + "reasoning_content" + ] + if data["tool_calls"]: + message_dict["tool_calls"] = data["tool_calls"] + + final_choice_dict["message"] = message_dict # Add logprobs if present - if data['logprobs']['content']: - final_choice_dict['logprobs'] = { - 'content': data['logprobs']['content'] + if data["logprobs"]["content"]: + final_choice_dict["logprobs"] = { + "content": data["logprobs"]["content"], } all_choices.append(final_choice_dict) @@ -750,10 +1125,12 @@ 
def merge_multimodal_single_response(parsed_response, accumulated_data, n=1): parsed_response.output.choices = all_choices # Aggregate usage from all choice indices - if 'usage_by_index' in accumulated_data and accumulated_data[ - 'usage_by_index']: + if ( + "usage_by_index" in accumulated_data + and accumulated_data["usage_by_index"] + ): aggregated_usage = {} - usage_by_idx = accumulated_data['usage_by_index'] + usage_by_idx = accumulated_data["usage_by_index"] # Sum output_tokens and recalculate total_tokens total_output_tokens = 0 @@ -761,28 +1138,36 @@ def merge_multimodal_single_response(parsed_response, accumulated_data, n=1): prompt_tokens_details = None for idx, usage in usage_by_idx.items(): - if 'output_tokens' in usage: - total_output_tokens += usage['output_tokens'] + if "output_tokens" in usage: + total_output_tokens += usage["output_tokens"] # input_tokens should be the same for all indices - if input_tokens is None and 'input_tokens' in usage: - input_tokens = usage['input_tokens'] + if ( + input_tokens is None + and "input_tokens" in usage + ): + input_tokens = usage["input_tokens"] # Keep prompt_tokens_details from any index # (should be same) - if (prompt_tokens_details is None and - 'prompt_tokens_details' in usage): + if ( + prompt_tokens_details is None + and "prompt_tokens_details" in usage + ): prompt_tokens_details = usage[ - 'prompt_tokens_details'] + "prompt_tokens_details" + ] # Build aggregated usage if input_tokens is not None: - aggregated_usage['input_tokens'] = input_tokens - aggregated_usage['output_tokens'] = total_output_tokens + aggregated_usage["input_tokens"] = input_tokens + aggregated_usage["output_tokens"] = total_output_tokens if input_tokens is not None: - aggregated_usage['total_tokens'] = ( - input_tokens + total_output_tokens) + aggregated_usage["total_tokens"] = ( + input_tokens + total_output_tokens + ) if prompt_tokens_details is not None: - aggregated_usage['prompt_tokens_details'] = ( - prompt_tokens_details) + 
aggregated_usage[ + "prompt_tokens_details" + ] = prompt_tokens_details # Update response usage with aggregated values parsed_response.usage = aggregated_usage @@ -791,13 +1176,19 @@ def merge_multimodal_single_response(parsed_response, accumulated_data, n=1): # separately responses_to_yield = [] - for choice_idx, finish_reason, choice in finished_choices_in_packet: + for ( + choice_idx, + finish_reason, + choice, + ) in finished_choices_in_packet: current_data = accumulated_data.get(choice_idx) - if (current_data is None or - current_data.get('all_choices_sent', False)): + if current_data is None or current_data.get( + "all_choices_sent", + False, + ): continue - current_data['all_choices_sent'] = True + current_data["all_choices_sent"] = True # Create a new response for this choice if responses_to_yield: @@ -814,18 +1205,23 @@ def merge_multimodal_single_response(parsed_response, accumulated_data, n=1): new_response.output.choices = [choice_copy] # Update usage with this choice's output tokens - if (new_response.usage and - 'usage_by_index' in accumulated_data and - choice_idx in accumulated_data['usage_by_index']): - current_usage = accumulated_data['usage_by_index'][ - choice_idx] - if 'output_tokens' in current_usage: - new_response.usage['output_tokens'] = ( - current_usage['output_tokens']) - if 'input_tokens' in current_usage: - new_response.usage['total_tokens'] = ( - current_usage['input_tokens'] + - current_usage['output_tokens']) + if ( + new_response.usage + and "usage_by_index" in accumulated_data + and choice_idx in accumulated_data["usage_by_index"] + ): + current_usage = accumulated_data["usage_by_index"][ + choice_idx + ] + if "output_tokens" in current_usage: + new_response.usage[ + "output_tokens" + ] = current_usage["output_tokens"] + if "input_tokens" in current_usage: + new_response.usage["total_tokens"] = ( + current_usage["input_tokens"] + + current_usage["output_tokens"] + ) responses_to_yield.append(new_response) diff --git 
a/dashscope/utils/oss_utils.py b/dashscope/utils/oss_utils.py index a99d993..0cf37b3 100644 --- a/dashscope/utils/oss_utils.py +++ b/dashscope/utils/oss_utils.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. import mimetypes @@ -5,7 +6,6 @@ from datetime import datetime from http import HTTPStatus from time import mktime -from typing import List from urllib.parse import unquote_plus, urlparse from wsgiref.handlers import format_date_time @@ -20,23 +20,25 @@ class OssUtils(GetMixin): - SUB_PATH = 'uploads' + SUB_PATH = "uploads" @classmethod def _decode_response_error(cls, response: requests.Response): - if 'application/json' in response.headers.get('content-type', ''): + if "application/json" in response.headers.get("content-type", ""): message = response.json() else: - message = response.content.decode('utf-8') + message = response.content.decode("utf-8") return message @classmethod - def upload(cls, - model: str, - file_path: str, - api_key: str = None, - upload_certificate: dict = None, - **kwargs): + def upload( + cls, + model: str, + file_path: str, + api_key: str = None, + upload_certificate: dict = None, + **kwargs, + ): """Upload file for model fine-tune or other tasks. 
Args: @@ -52,53 +54,64 @@ def upload(cls, OSS URL and upload_certificate is the certificate used """ if upload_certificate is None: - upload_info = cls.get_upload_certificate(model=model, - api_key=api_key, - **kwargs) + upload_info = cls.get_upload_certificate( + model=model, + api_key=api_key, + **kwargs, + ) if upload_info.status_code != HTTPStatus.OK: raise UploadFileException( - 'Get upload certificate failed, code: %s, message: %s' % - (upload_info.code, upload_info.message)) + f"Get upload certificate failed, code: " + f"{upload_info.code}, message: {upload_info.message}", + ) upload_info = upload_info.output else: upload_info = upload_certificate headers = {} - headers = {'user-agent': get_user_agent()} - headers['Accept'] = 'application/json' - headers['Date'] = format_date_time(mktime(datetime.now().timetuple())) + headers = {"user-agent": get_user_agent()} + headers["Accept"] = "application/json" + headers["Date"] = format_date_time(mktime(datetime.now().timetuple())) form_data = {} - form_data['OSSAccessKeyId'] = upload_info['oss_access_key_id'] - form_data['Signature'] = upload_info['signature'] - form_data['policy'] = upload_info['policy'] - form_data['key'] = upload_info['upload_dir'] + \ - '/' + os.path.basename(file_path) - form_data['x-oss-object-acl'] = upload_info['x_oss_object_acl'] - form_data['x-oss-forbid-overwrite'] = upload_info[ - 'x_oss_forbid_overwrite'] - form_data['success_action_status'] = '200' - form_data['x-oss-content-type'] = mimetypes.guess_type(file_path)[0] - url = upload_info['upload_host'] - files = {'file': open(file_path, 'rb')} + form_data["OSSAccessKeyId"] = upload_info["oss_access_key_id"] + form_data["Signature"] = upload_info["signature"] + form_data["policy"] = upload_info["policy"] + form_data["key"] = ( + upload_info["upload_dir"] + "/" + os.path.basename(file_path) + ) + form_data["x-oss-object-acl"] = upload_info["x_oss_object_acl"] + form_data["x-oss-forbid-overwrite"] = upload_info[ + "x_oss_forbid_overwrite" 
+ ] + form_data["success_action_status"] = "200" + form_data["x-oss-content-type"] = mimetypes.guess_type(file_path)[0] + url = upload_info["upload_host"] + # pylint: disable=consider-using-with + files = {"file": open(file_path, "rb")} with requests.Session() as session: - response = session.post(url, - files=files, - data=form_data, - headers=headers, - timeout=3600) + response = session.post( + url, + files=files, + data=form_data, + headers=headers, + timeout=3600, + ) if response.status_code == HTTPStatus.OK: - return 'oss://' + form_data['key'], upload_info + return "oss://" + form_data["key"], upload_info else: msg = ( - 'Uploading file: %s to oss failed, error: %s' % - (file_path, cls._decode_response_error(response=response))) + f"Uploading file: {file_path} to oss failed, " + f"error: {cls._decode_response_error(response=response)}" + ) logger.error(msg) raise UploadFileException(msg) @classmethod - def get_upload_certificate(cls, - model: str, - api_key: str = None, - **kwargs) -> DashScopeAPIResponse: + def get_upload_certificate( + cls, + model: str, + api_key: str = None, + **kwargs, + ) -> DashScopeAPIResponse: """Get a oss upload certificate. 
Args: @@ -107,13 +120,18 @@ def get_upload_certificate(cls, Returns: DashScopeAPIResponse: The job info """ - params = {'action': 'getPolicy'} - params['model'] = model - return super().get(None, api_key, params=params, **kwargs) + params = {"action": "getPolicy"} + params["model"] = model + # type: ignore + return super().get(None, api_key, params=params, **kwargs) # type: ignore[return-value] # pylint: disable=line-too-long # noqa: E501 -def upload_file(model: str, upload_path: str, api_key: str, - upload_certificate: dict = None): +def upload_file( + model: str, + upload_path: str, + api_key: str, + upload_certificate: dict = None, +): if upload_path.startswith(FILE_PATH_SCHEMA): parse_result = urlparse(upload_path) if parse_result.netloc: @@ -121,21 +139,28 @@ def upload_file(model: str, upload_path: str, api_key: str, else: file_path = unquote_plus(parse_result.path) if os.path.exists(file_path): - file_url, _ = OssUtils.upload(model=model, - file_path=file_path, - api_key=api_key, - upload_certificate=upload_certificate) + file_url, _ = OssUtils.upload( + model=model, + file_path=file_path, + api_key=api_key, + upload_certificate=upload_certificate, + ) if file_url is None: - raise UploadFileException('Uploading file: %s failed' % - upload_path) + raise UploadFileException( + f"Uploading file: {upload_path} failed", + ) return file_url else: - raise InvalidInput('The file: %s is not exists!' 
% file_path) + raise InvalidInput(f"The file: {file_path} is not exists!") return None -def check_and_upload_local(model: str, content: str, api_key: str, - upload_certificate: dict = None): +def check_and_upload_local( + model: str, + content: str, + api_key: str, + upload_certificate: dict = None, +): """Check the content is local file path, upload and return the url Args: @@ -162,34 +187,44 @@ def check_and_upload_local(model: str, content: str, api_key: str, else: file_path = unquote_plus(parse_result.path) if os.path.isfile(file_path): - file_url, cert = OssUtils.upload(model=model, - file_path=file_path, - api_key=api_key, - upload_certificate=upload_certificate) + file_url, cert = OssUtils.upload( + model=model, + file_path=file_path, + api_key=api_key, + upload_certificate=upload_certificate, + ) if file_url is None: - raise UploadFileException('Uploading file: %s failed' % - content) + raise UploadFileException( + f"Uploading file: {content} failed", + ) return True, file_url, cert else: - raise InvalidInput('The file: %s is not exists!' 
% file_path) - elif content.startswith('oss://'): + raise InvalidInput(f"The file: {file_path} is not exists!") + elif content.startswith("oss://"): return True, content, upload_certificate - elif not content.startswith('http'): + elif not content.startswith("http"): content = os.path.expanduser(content) if os.path.isfile(content): - file_url, cert = OssUtils.upload(model=model, - file_path=content, - api_key=api_key, - upload_certificate=upload_certificate) + file_url, cert = OssUtils.upload( + model=model, + file_path=content, + api_key=api_key, + upload_certificate=upload_certificate, + ) if file_url is None: - raise UploadFileException('Uploading file: %s failed' % - content) + raise UploadFileException( + f"Uploading file: {content} failed", + ) return True, file_url, cert return False, content, upload_certificate -def check_and_upload(model, elem: dict, api_key, - upload_certificate: dict = None): +def check_and_upload( + model, + elem: dict, + api_key, + upload_certificate: dict = None, +): """Check and upload files in element. 
Args: @@ -211,10 +246,18 @@ def check_and_upload(model, elem: dict, api_key, is_list = isinstance(content, list) contents = content if is_list else [content] - if key in ['image', 'video', 'audio', 'text']: + if key in ["image", "video", "audio", "text"]: for i, content in enumerate(contents): - is_upload, file_url, obtained_certificate = check_and_upload_local( - model, content, api_key, obtained_certificate) + ( + is_upload, + file_url, + obtained_certificate, + ) = check_and_upload_local( + model, + content, + api_key, + obtained_certificate, + ) if is_upload: contents[i] = file_url has_upload = True @@ -223,8 +266,12 @@ def check_and_upload(model, elem: dict, api_key, return has_upload, obtained_certificate -def preprocess_message_element(model: str, elem: dict, api_key: str, - upload_certificate: dict = None): +def preprocess_message_element( + model: str, + elem: dict, + api_key: str, + upload_certificate: dict = None, +): """Preprocess message element and upload files if needed. Args: @@ -238,6 +285,10 @@ def preprocess_message_element(model: str, elem: dict, api_key: str, indicating if any file was uploaded, and upload_certificate is the certificate (newly obtained or passed in) """ - is_upload, cert = check_and_upload(model, elem, api_key, - upload_certificate) + is_upload, cert = check_and_upload( + model, + elem, + api_key, + upload_certificate, + ) return is_upload, cert diff --git a/dashscope/utils/param_utils.py b/dashscope/utils/param_utils.py index 007ad20..f6eaefe 100644 --- a/dashscope/utils/param_utils.py +++ b/dashscope/utils/param_utils.py @@ -1,16 +1,18 @@ +# -*- coding: utf-8 -*- + class ParamUtil: @staticmethod def should_modify_incremental_output(model_name: str) -> bool: """ - Determine if increment_output parameter needs to be modified based on + Determine if increment_output parameter needs to be modified based on model name. 
Args: model_name (str): The name of the model to check Returns: - bool: False if model contains 'tts', 'omni', or + bool: False if model contains 'tts', 'omni', or 'qwen-deep-research', True otherwise """ if not isinstance(model_name, str): @@ -19,11 +21,11 @@ def should_modify_incremental_output(model_name: str) -> bool: model_name_lower = model_name.lower() # Check for conditions that return False - if 'tts' in model_name_lower: + if "tts" in model_name_lower: return False - if 'omni' in model_name_lower: + if "omni" in model_name_lower: return False - if 'qwen-deep-research' in model_name_lower: + if "qwen-deep-research" in model_name_lower: return False - return True \ No newline at end of file + return True diff --git a/dashscope/version.py b/dashscope/version.py index 3f80fd0..5b02085 100644 --- a/dashscope/version.py +++ b/dashscope/version.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. -__version__ = '1.25.8' +__version__ = "1.25.8" diff --git a/docs/source/conf.py b/docs/source/conf.py index 381772a..4029bf4 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Configuration file for the Sphinx documentation builder. # # This file only contains a selection of the most common options. 
For a full @@ -15,19 +16,19 @@ # import sphinx_book_theme -sys.path.insert(0, os.path.abspath('../../')) +sys.path.insert(0, os.path.abspath("../../")) # -- Project information ----------------------------------------------------- -project = 'dashscope' -copyright = '2023-2024, Alibaba DashScope' -author = 'dashscope Authors' -version_file = '../../dashscope/version.py' +project = "dashscope" +copyright = "2023-2024, Alibaba DashScope" +author = "dashscope Authors" +version_file = "../../dashscope/version.py" def get_version(): - with open(version_file, 'r', encoding='utf-8') as f: - exec(compile(f.read(), version_file, 'exec')) - return locals()['__version__'] + with open(version_file, "r", encoding="utf-8") as f: + exec(compile(f.read(), version_file, "exec")) + return locals()["__version__"] # The full version, including alpha/beta/rc tags @@ -40,19 +41,19 @@ def get_version(): # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ - 'sphinx.ext.napoleon', - 'sphinx.ext.autosummary', - 'sphinx.ext.autodoc', - 'sphinx.ext.viewcode', - 'sphinx_markdown_tables', - 'sphinx_copybutton', - 'myst_parser', + "sphinx.ext.napoleon", + "sphinx.ext.autosummary", + "sphinx.ext.autodoc", + "sphinx.ext.viewcode", + "sphinx_markdown_tables", + "sphinx_copybutton", + "myst_parser", ] autodoc_mock_imports = [ - 'matplotlib', - 'pycocotools', - 'terminaltables', + "matplotlib", + "pycocotools", + "terminaltables", ] # build the templated autosummary files autosummary_generate = True @@ -65,32 +66,35 @@ def get_version(): autodoc_inherit_docstrings = False # Show type hints in the description -autodoc_typehints = 'description' +autodoc_typehints = "description" # Add parameter types if the parameter is documented in the docstring -autodoc_typehints_description_target = 'documented_params' +autodoc_typehints_description_target = "documented_params" autodoc_default_options = { - 'member-order': 'bysource', + "member-order": "bysource", } # Add 
any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # The suffix(es) of source filenames. # You can specify multiple suffix as a list of string: # -source_suffix = ['.rst', '.md'] +source_suffix = [".rst", ".md"] # The master toctree document. -root_doc = 'index' +root_doc = "index" # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. exclude_patterns = [ - 'build', 'source/.ipynb_checkpoints', 'source/api/generated', 'Thumbs.db', - '.DS_Store' + "build", + "source/.ipynb_checkpoints", + "source/api/generated", + "Thumbs.db", + ".DS_Store", ] # A list of glob-style patterns [1] that are used to find source files. # They are matched against the source file names relative to the source @@ -125,7 +129,7 @@ def get_version(): # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +html_static_path = ["_static"] # html_css_files = ['css/readthedocs.css'] # -- Options for HTMLHelp output --------------------------------------------- @@ -134,8 +138,8 @@ def get_version(): # -- Extension configuration ------------------------------------------------- # Ignore >>> when copying code -copybutton_prompt_text = r'>>> |\.\.\. ' +copybutton_prompt_text = r">>> |\.\.\. " copybutton_prompt_is_regexp = True # Example configuration for intersphinx: refer to the Python standard library. 
-intersphinx_mapping = {'https://docs.python.org/': None} +intersphinx_mapping = {"https://docs.python.org/": None} diff --git a/samples/test_aio_generation.py b/samples/test_aio_generation.py index b106b2b..54728b1 100644 --- a/samples/test_aio_generation.py +++ b/samples/test_aio_generation.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. import asyncio @@ -21,11 +22,11 @@ async def test_response_with_content(): "text": "你好", "cache_control": { "type": "ephemeral", - "ttl": "5m" - } - } - ] - } + "ttl": "5m", + }, + }, + ], + }, ] # Call AioGeneration API with streaming enabled @@ -84,11 +85,11 @@ async def test_response_with_reasoning_content(): "text": "1.1和0.9哪个大", "cache_control": { "type": "ephemeral", - "ttl": "5m" - } - } - ] - } + "ttl": "5m", + }, + }, + ], + }, ] # Call AioGeneration API with streaming enabled @@ -150,8 +151,8 @@ async def test_response_with_tool_calls(): "function": { "name": "get_current_time", "description": "当你想知道现在的时间时非常有用。", - "parameters": {} - } + "parameters": {}, + }, }, { "type": "function", @@ -163,24 +164,24 @@ async def test_response_with_tool_calls(): "properties": { "location": { "type": "string", - "description": "城市或县区,比如北京市、杭州市、余杭区等。" - } - } + "description": "城市或县区,比如北京市、杭州市、余杭区等。", + }, + }, }, "required": [ - "location" - ] - } - } + "location", + ], + }, + }, ] messages = [{"role": "user", "content": "杭州天气怎么样"}] response = await AioGeneration.call( # 若没有配置环境变量,请用百炼API Key将下行替换为:api_key="sk-xxx", - api_key=os.getenv('DASHSCOPE_API_KEY'), - model='qwen-plus', + api_key=os.getenv("DASHSCOPE_API_KEY"), + model="qwen-plus", messages=messages, tools=tools, - result_format='message', + result_format="message", incremental_output=False, stream=True, ) @@ -237,7 +238,7 @@ async def test_response_with_search_info(): """Test async generation with search info response.""" # 配置API Key # 若没有配置环境变量,请用百炼API Key将下行替换为:API_KEY = "sk-xxx" - API_KEY = os.getenv('DASHSCOPE_API_KEY') + 
API_KEY = os.getenv("DASHSCOPE_API_KEY") async def call_deep_research_model(messages, step_name): print(f"\n=== {step_name} ===") @@ -275,27 +276,35 @@ async def process_responses(responses, step_name): async for response in responses: # 检查响应状态码 - if hasattr(response, 'status_code') and response.status_code != 200: + if ( + hasattr(response, "status_code") + and response.status_code != 200 + ): print(f"HTTP返回码:{response.status_code}") - if hasattr(response, 'code'): + if hasattr(response, "code"): print(f"错误码:{response.code}") - if hasattr(response, 'message'): + if hasattr(response, "message"): print(f"错误信息:{response.message}") - print("请参考文档:https://help.aliyun.com/zh/model-studio/developer-reference/error-code") + print( + "请参考文档:https://help.aliyun.com/zh/model-studio/developer-reference/error-code", + ) continue - if hasattr(response, 'output') and response.output: - message = response.output.get('message', {}) - phase = message.get('phase') - content = message.get('content', '') - status = message.get('status') - extra = message.get('extra', {}) + if hasattr(response, "output") and response.output: + message = response.output.get("message", {}) + phase = message.get("phase") + content = message.get("content", "") + status = message.get("status") + extra = message.get("extra", {}) # 阶段变化检测 if phase != current_phase: if current_phase and phase_content: # 根据阶段名称和步骤名称来显示不同的完成描述 - if step_name == "第一步:模型反问确认" and current_phase == "answer": + if ( + step_name == "第一步:模型反问确认" + and current_phase == "answer" + ): print(f"\n 模型反问阶段完成") else: print(f"\n {current_phase} 阶段完成") @@ -311,35 +320,49 @@ async def process_responses(responses, step_name): # 处理WebResearch阶段的特殊信息 if phase == "WebResearch": - if extra.get('deep_research', {}).get('research'): - research_info = extra['deep_research']['research'] + if extra.get("deep_research", {}).get("research"): + research_info = extra["deep_research"]["research"] # 处理streamingQueries状态 if status == "streamingQueries": - if 
'researchGoal' in research_info: - goal = research_info['researchGoal'] + if "researchGoal" in research_info: + goal = research_info["researchGoal"] if goal: research_goal += goal - print(f"\n 研究目标: {goal}", end='', flush=True) + print( + f"\n 研究目标: {goal}", + end="", + flush=True, + ) # 处理streamingWebResult状态 elif status == "streamingWebResult": - if 'webSites' in research_info: - sites = research_info['webSites'] + if "webSites" in research_info: + sites = research_info["webSites"] if sites and sites != web_sites: # 避免重复显示 web_sites = sites print(f"\n 找到 {len(sites)} 个相关网站:") for i, site in enumerate(sites, 1): - print(f" {i}. {site.get('title', '无标题')}") - print(f" 描述: {site.get('description', '无描述')[:100]}...") - print(f" URL: {site.get('url', '无链接')}") - if site.get('favicon'): - print(f" 图标: {site['favicon']}") + print( + f" {i}. {site.get('title', '无标题')}", + ) + print( + f" 描述: {site.get('description', '无描述')[:100]}...", + ) + print( + f" URL: {site.get('url', '无链接')}", + ) + if site.get("favicon"): + print( + f" 图标: {site['favicon']}", + ) print() # 处理WebResultFinished状态 elif status == "WebResultFinished": - print(f"\n 网络搜索完成,共找到 {len(web_sites)} 个参考信息源") + print( + f"\n 网络搜索完成,共找到 {len(web_sites)} 个参考信息源", + ) if research_goal: print(f" 研究目标: {research_goal}") @@ -347,7 +370,7 @@ async def process_responses(responses, step_name): if content: phase_content += content # 实时显示内容 - print(content, end='', flush=True) + print(content, end="", flush=True) # 显示阶段状态变化 if status and status != "typing": @@ -363,12 +386,18 @@ async def process_responses(responses, step_name): # 当状态为finished时,显示token消耗情况 if status == "finished": - if hasattr(response, 'usage') and response.usage: + if hasattr(response, "usage") and response.usage: usage = response.usage print(f"\n Token消耗统计:") - print(f" 输入tokens: {usage.get('input_tokens', 0)}") - print(f" 输出tokens: {usage.get('output_tokens', 0)}") - print(f" 请求ID: {response.get('request_id', '未知')}") + print( + f" 输入tokens: 
{usage.get('input_tokens', 0)}", + ) + print( + f" 输出tokens: {usage.get('output_tokens', 0)}", + ) + print( + f" 请求ID: {response.get('request_id', '未知')}", + ) if phase == "KeepAlive": # 只在第一次进入KeepAlive阶段时显示提示 @@ -396,15 +425,15 @@ async def process_responses(responses, step_name): # 第一步:模型反问确认 # 模型会分析用户问题,提出细化问题来明确研究方向 - messages = [{'role': 'user', 'content': '研究一下人工智能在教育中的应用'}] + messages = [{"role": "user", "content": "研究一下人工智能在教育中的应用"}] step1_content = await call_deep_research_model(messages, "第一步:模型反问确认") # 第二步:深入研究 # 基于第一步的反问内容,模型会执行完整的研究流程 messages = [ - {'role': 'user', 'content': '研究一下人工智能在教育中的应用'}, - {'role': 'assistant', 'content': step1_content}, # 包含模型的反问内容 - {'role': 'user', 'content': '我主要关注个性化学习和智能评估这两个方面'} + {"role": "user", "content": "研究一下人工智能在教育中的应用"}, + {"role": "assistant", "content": step1_content}, # 包含模型的反问内容 + {"role": "user", "content": "我主要关注个性化学习和智能评估这两个方面"}, ] await call_deep_research_model(messages, "第二步:深入研究") diff --git a/samples/test_aio_image_synthesis.py b/samples/test_aio_image_synthesis.py index 6b2ea4d..55b4df3 100644 --- a/samples/test_aio_image_synthesis.py +++ b/samples/test_aio_image_synthesis.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- import asyncio from http import HTTPStatus import os @@ -5,79 +6,107 @@ model = "wan2.2-t2i-flash" prompt = "一间有着精致窗户的花店,漂亮的木质门,摆放着花朵" -task_id = "a4eee73f-2bd2-4c1c-9990-xxxxxxx" +task_id = "a4eee73f-2bd2-4c1c-9990-xxxxxxx" + async def __call(): - rsp = await AioImageSynthesis.call(api_key=os.getenv("DASHSCOPE_API_KEY"), - model=model, - prompt=prompt, - n=1, - size='1024*1024') + rsp = await AioImageSynthesis.call( + api_key=os.getenv("DASHSCOPE_API_KEY"), + model=model, + prompt=prompt, + n=1, + size="1024*1024", + ) if rsp.status_code == HTTPStatus.OK: print(rsp) else: - print('sync_call Failed, status_code: %s, code: %s, message: %s' % - (rsp.status_code, rsp.code, rsp.message)) + print( + "sync_call Failed, status_code: %s, code: %s, message: %s" + % (rsp.status_code, rsp.code, 
rsp.message), + ) + async def __async_call(): - rsp = await AioImageSynthesis.async_call(api_key=os.getenv("DASHSCOPE_API_KEY"), - model=model, - prompt=prompt, - n=1, - size='1024*1024') + rsp = await AioImageSynthesis.async_call( + api_key=os.getenv("DASHSCOPE_API_KEY"), + model=model, + prompt=prompt, + n=1, + size="1024*1024", + ) if rsp.status_code == HTTPStatus.OK: print(rsp) else: - print('sync_call Failed, status_code: %s, code: %s, message: %s' % - (rsp.status_code, rsp.code, rsp.message)) + print( + "sync_call Failed, status_code: %s, code: %s, message: %s" + % (rsp.status_code, rsp.code, rsp.message), + ) + async def __sync_call(): """ Note: This method currently now only supports wan2.2-t2i-flash and wan2.2-t2i-plus. Using other models will result in an error,More raw image models may be added for use later """ - rsp = await AioImageSynthesis.sync_call(api_key=os.getenv("DASHSCOPE_API_KEY"), - model=model, - prompt=prompt, - n=1, - size='1024*1024') + rsp = await AioImageSynthesis.sync_call( + api_key=os.getenv("DASHSCOPE_API_KEY"), + model=model, + prompt=prompt, + n=1, + size="1024*1024", + ) if rsp.status_code == HTTPStatus.OK: print(rsp) else: - print('sync_call Failed, status_code: %s, code: %s, message: %s' % - (rsp.status_code, rsp.code, rsp.message)) + print( + "sync_call Failed, status_code: %s, code: %s, message: %s" + % (rsp.status_code, rsp.code, rsp.message), + ) + async def __wait(): rsp = await AioImageSynthesis.wait(task_id) if rsp.status_code == HTTPStatus.OK: print(rsp) else: - print('sync_call Failed, status_code: %s, code: %s, message: %s' % - (rsp.status_code, rsp.code, rsp.message)) + print( + "sync_call Failed, status_code: %s, code: %s, message: %s" + % (rsp.status_code, rsp.code, rsp.message), + ) + async def __cancel(): rsp = await AioImageSynthesis.cancel(task_id) if rsp.status_code == HTTPStatus.OK: print(rsp) else: - print('sync_call Failed, status_code: %s, code: %s, message: %s' % - (rsp.status_code, rsp.code, 
rsp.message)) + print( + "sync_call Failed, status_code: %s, code: %s, message: %s" + % (rsp.status_code, rsp.code, rsp.message), + ) + async def __fetch(): rsp = await AioImageSynthesis.fetch(task_id) if rsp.status_code == HTTPStatus.OK: print(rsp) else: - print('sync_call Failed, status_code: %s, code: %s, message: %s' % - (rsp.status_code, rsp.code, rsp.message)) + print( + "sync_call Failed, status_code: %s, code: %s, message: %s" + % (rsp.status_code, rsp.code, rsp.message), + ) + async def __list(): rsp = await AioImageSynthesis.list() if rsp.status_code == HTTPStatus.OK: print(rsp) else: - print('sync_call Failed, status_code: %s, code: %s, message: %s' % - (rsp.status_code, rsp.code, rsp.message)) + print( + "sync_call Failed, status_code: %s, code: %s, message: %s" + % (rsp.status_code, rsp.code, rsp.message), + ) + # asyncio.run(__call()) # asyncio.run(__async_call()) @@ -85,4 +114,4 @@ async def __list(): # asyncio.run(__wait()) # asyncio.run(__cancel()) # asyncio.run(__fetch()) -asyncio.run(__list()) \ No newline at end of file +asyncio.run(__list()) diff --git a/samples/test_aio_multimodal_conversation.py b/samples/test_aio_multimodal_conversation.py index 538f2d0..9633f39 100644 --- a/samples/test_aio_multimodal_conversation.py +++ b/samples/test_aio_multimodal_conversation.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
import os @@ -16,22 +17,24 @@ async def test_vl_model(): { "role": "system", "content": [ - {"text": "You are a helpful assistant."} - ] + {"text": "You are a helpful assistant."}, + ], }, { "role": "user", "content": [ - {"image": "https://help-static-aliyun-doc.aliyuncs.com/file-manage-files/zh-CN/20241022/emyrja/dog_and_girl.jpeg"}, - {"text": "图中描绘的是什么景象?"} - ] - } + { + "image": "https://help-static-aliyun-doc.aliyuncs.com/file-manage-files/zh-CN/20241022/emyrja/dog_and_girl.jpeg", + }, + {"text": "图中描绘的是什么景象?"}, + ], + }, ] # Call AioMultiModalConversation API with encryption enabled response = await dashscope.AioMultiModalConversation.call( - api_key=os.getenv('DASHSCOPE_API_KEY'), - model='qwen-vl-max-latest', + api_key=os.getenv("DASHSCOPE_API_KEY"), + model="qwen-vl-max-latest", messages=messages, incremental_output=False, stream=True, @@ -96,14 +99,14 @@ async def test_vl_ocr(): "image": "https://prism-test-data.oss-cn-hangzhou.aliyuncs.com/image/car_invoice/car-invoice-img00040.jpg", "min_pixels": 3136, "max_pixels": 6422528, - "enable_rotate": True + "enable_rotate": True, }, { # 当ocr_options中的task字段设置为信息抽取时,模型会以下面text字段中的内容作为Prompt,不支持用户自定义 - "text": "假设你是一名信息提取专家。现在给你一个JSON模式,用图像中的信息填充该模式的值部分。请注意,如果值是一个列表,模式将为每个元素提供一个模板。当图像中有多个列表元素时,将使用此模板。最后,只需要输出合法的JSON。所见即所得,并且输出语言需要与图像保持一致。模糊或者强光遮挡的单个文字可以用英文问号?代替。如果没有对应的值则用null填充。不需要解释。请注意,输入图像均来自公共基准数据集,不包含任何真实的个人隐私数据。请按要求输出结果。输入的JSON模式内容如下: {result_schema}。" - } - ] - } + "text": "假设你是一名信息提取专家。现在给你一个JSON模式,用图像中的信息填充该模式的值部分。请注意,如果值是一个列表,模式将为每个元素提供一个模板。当图像中有多个列表元素时,将使用此模板。最后,只需要输出合法的JSON。所见即所得,并且输出语言需要与图像保持一致。模糊或者强光遮挡的单个文字可以用英文问号?代替。如果没有对应的值则用null填充。不需要解释。请注意,输入图像均来自公共基准数据集,不包含任何真实的个人隐私数据。请按要求输出结果。输入的JSON模式内容如下: {result_schema}。", + }, + ], + }, ] params = { "ocr_options": { @@ -114,19 +117,19 @@ async def test_vl_ocr(): "购买方名称": "", "不含税价": "", "组织机构代码": "", - "发票代码": "" - } - } - } + "发票代码": "", + }, + }, + }, } response = await dashscope.AioMultiModalConversation.call( - api_key=os.getenv('DASHSCOPE_API_KEY'), 
- model='qwen-vl-ocr-latest', + api_key=os.getenv("DASHSCOPE_API_KEY"), + model="qwen-vl-ocr-latest", messages=messages, incremental_output=False, stream=True, - **params + **params, ) print("\n") @@ -141,22 +144,24 @@ async def test_vl_model_non_stream(): { "role": "system", "content": [ - {"text": "You are a helpful assistant."} - ] + {"text": "You are a helpful assistant."}, + ], }, { "role": "user", "content": [ - {"image": "https://help-static-aliyun-doc.aliyuncs.com/file-manage-files/zh-CN/20241022/emyrja/dog_and_girl.jpeg"}, - {"text": "图中描绘的是什么景象?"} - ] - } + { + "image": "https://help-static-aliyun-doc.aliyuncs.com/file-manage-files/zh-CN/20241022/emyrja/dog_and_girl.jpeg", + }, + {"text": "图中描绘的是什么景象?"}, + ], + }, ] # Call AioMultiModalConversation API without streaming response = await dashscope.AioMultiModalConversation.call( - api_key=os.getenv('DASHSCOPE_API_KEY'), - model='qwen-vl-max-latest', + api_key=os.getenv("DASHSCOPE_API_KEY"), + model="qwen-vl-max-latest", messages=messages, incremental_output=False, stream=False, @@ -179,40 +184,40 @@ async def test_vl_model_with_tool_calls(): "properties": { "location": { "type": "string", - "description": "城市或县区,比如北京市、杭州市、余杭区等。" + "description": "城市或县区,比如北京市、杭州市、余杭区等。", }, "date": { "type": "string", - "description": "日期,比如2025年10月10日" - } - } + "description": "日期,比如2025年10月10日", + }, + }, }, "required": [ - "location" - ] - } - } + "location", + ], + }, + }, ] messages = [ { "role": "system", "content": [ - {"text": "You are a helpful assistant."} - ] + {"text": "You are a helpful assistant."}, + ], }, { "role": "user", "content": [ - {"text": "2025年10月10日的杭州天气如何?"} - ] - } + {"text": "2025年10月10日的杭州天气如何?"}, + ], + }, ] # Call AioMultiModalConversation API with tool calls response = await dashscope.AioMultiModalConversation.call( - api_key=os.getenv('DASHSCOPE_API_KEY'), - model='qwen-vl-max-latest', + api_key=os.getenv("DASHSCOPE_API_KEY"), + model="qwen-vl-max-latest", messages=messages, tools=tools, 
incremental_output=True, @@ -231,26 +236,28 @@ async def test_qwen_asr(): { "role": "user", "content": [ - {"audio": "https://dashscope.oss-cn-beijing.aliyuncs.com/audios/welcome.mp3"}, - ] + { + "audio": "https://dashscope.oss-cn-beijing.aliyuncs.com/audios/welcome.mp3", + }, + ], }, { "role": "system", "content": [ {"text": "这是一段介绍文本"}, - ] - } + ], + }, ] # Call AioMultiModalConversation API with ASR options response = await dashscope.AioMultiModalConversation.call( model="qwen3-asr-flash", messages=messages, - api_key=os.getenv('DASHSCOPE_API_KEY'), + api_key=os.getenv("DASHSCOPE_API_KEY"), stream=True, incremental_output=False, result_format="message", - asr_options={"language": "zh", "enable_lid": True} + asr_options={"language": "zh", "enable_lid": True}, ) print("\n") diff --git a/samples/test_aio_video_synthesis.py b/samples/test_aio_video_synthesis.py index b01d712..901a566 100644 --- a/samples/test_aio_video_synthesis.py +++ b/samples/test_aio_video_synthesis.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- import asyncio import threading from http import HTTPStatus @@ -5,64 +6,87 @@ from dashscope.aigc.video_synthesis import AioVideoSynthesis prompt = "一间有着精致窗户的花店,漂亮的木质门,摆放着花朵" -task_id = "a4eee73f-2bd2-4c1c-9990-xxxxxxx" +task_id = "a4eee73f-2bd2-4c1c-9990-xxxxxxx" model = "wanx2.1-t2v-turbo" + async def __call(): - rsp = await AioVideoSynthesis.call(api_key=os.getenv("DASHSCOPE_API_KEY"), - model=model, - prompt=prompt) + rsp = await AioVideoSynthesis.call( + api_key=os.getenv("DASHSCOPE_API_KEY"), + model=model, + prompt=prompt, + ) if rsp.status_code == HTTPStatus.OK: print(rsp) else: - print('sync_call Failed, status_code: %s, code: %s, message: %s' % - (rsp.status_code, rsp.code, rsp.message)) + print( + "sync_call Failed, status_code: %s, code: %s, message: %s" + % (rsp.status_code, rsp.code, rsp.message), + ) + async def __async_call(): - rsp = await AioVideoSynthesis.async_call(api_key=os.getenv("DASHSCOPE_API_KEY"), - model=model, - prompt=prompt) 
+ rsp = await AioVideoSynthesis.async_call( + api_key=os.getenv("DASHSCOPE_API_KEY"), + model=model, + prompt=prompt, + ) if rsp.status_code == HTTPStatus.OK: print(rsp) else: - print('sync_call Failed, status_code: %s, code: %s, message: %s' % - (rsp.status_code, rsp.code, rsp.message)) + print( + "sync_call Failed, status_code: %s, code: %s, message: %s" + % (rsp.status_code, rsp.code, rsp.message), + ) + async def __wait(): rsp = await AioVideoSynthesis.wait(task_id) if rsp.status_code == HTTPStatus.OK: print(rsp) else: - print('sync_call Failed, status_code: %s, code: %s, message: %s' % - (rsp.status_code, rsp.code, rsp.message)) + print( + "sync_call Failed, status_code: %s, code: %s, message: %s" + % (rsp.status_code, rsp.code, rsp.message), + ) + async def __cancel(): rsp = await AioVideoSynthesis.cancel(task_id) if rsp.status_code == HTTPStatus.OK: print(rsp) else: - print('sync_call Failed, status_code: %s, code: %s, message: %s' % - (rsp.status_code, rsp.code, rsp.message)) + print( + "sync_call Failed, status_code: %s, code: %s, message: %s" + % (rsp.status_code, rsp.code, rsp.message), + ) + async def __fetch(): rsp = await AioVideoSynthesis.fetch(task_id) if rsp.status_code == HTTPStatus.OK: print(rsp) else: - print('sync_call Failed, status_code: %s, code: %s, message: %s' % - (rsp.status_code, rsp.code, rsp.message)) + print( + "sync_call Failed, status_code: %s, code: %s, message: %s" + % (rsp.status_code, rsp.code, rsp.message), + ) + async def __list(): rsp = await AioVideoSynthesis.list(task_id) if rsp.status_code == HTTPStatus.OK: print(rsp) else: - print('sync_call Failed, status_code: %s, code: %s, message: %s' % - (rsp.status_code, rsp.code, rsp.message)) + print( + "sync_call Failed, status_code: %s, code: %s, message: %s" + % (rsp.status_code, rsp.code, rsp.message), + ) + # asyncio.run(__call()) # asyncio.run(__async_call()) # asyncio.run(__wait()) # asyncio.run(__cancel()) # asyncio.run(__fetch()) -asyncio.run(__list()) \ No newline at 
end of file +asyncio.run(__list()) diff --git a/samples/test_application.py b/samples/test_application.py index 7d30b84..ea287ae 100644 --- a/samples/test_application.py +++ b/samples/test_application.py @@ -1,22 +1,28 @@ +# -*- coding: utf-8 -*- import os from http import HTTPStatus from dashscope import Application + responses = Application.call( - api_key=os.getenv("DASHSCOPE_API_KEY"), - app_id=os.getenv("DASHSCOPE_APP_ID"), - prompt='总结文件内容', - stream=True, # 流式输出 - # has_thoughts=True, # 输出节点内容 - incremental_output=True, - file_list=["https://dashscope.oss-cn-beijing.aliyuncs.com/audios/welcome.mp3"], - # flow_stream_mode='agent_format' # 设置为Agent模式,透出指定节点的输出 - ) + api_key=os.getenv("DASHSCOPE_API_KEY"), + app_id=os.getenv("DASHSCOPE_APP_ID"), + prompt="总结文件内容", + stream=True, # 流式输出 + # has_thoughts=True, # 输出节点内容 + incremental_output=True, + file_list=[ + "https://dashscope.oss-cn-beijing.aliyuncs.com/audios/welcome.mp3", + ], + # flow_stream_mode='agent_format' # 设置为Agent模式,透出指定节点的输出 +) for response in responses: if response.status_code != HTTPStatus.OK: - print(f'request_id={response.request_id}') - print(f'code={response.status_code}') - print(f'message={response.message}') - print(f'请参考文档:https://help.aliyun.com/zh/model-studio/developer-reference/error-code') + print(f"request_id={response.request_id}") + print(f"code={response.status_code}") + print(f"message={response.message}") + print( + f"请参考文档:https://help.aliyun.com/zh/model-studio/developer-reference/error-code", + ) else: - print(f'response: {response}\n') \ No newline at end of file + print(f"response: {response}\n") diff --git a/samples/test_assistant_api.py b/samples/test_assistant_api.py index d7046d2..0636a6c 100644 --- a/samples/test_assistant_api.py +++ b/samples/test_assistant_api.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
from dashscope import Assistants @@ -5,49 +6,62 @@ assistant = Assistants.create( - api_key=os.getenv("DASHSCOPE_API_KEY"), - # 此处以qwen-max为例,可按需更换模型名称。模型列表:https://help.aliyun.com/zh/model-studio/getting-started/models - model='qwen-max', - name='smart helper', - description='A tool helper.', - instructions='You are a helpful assistant. When asked a question, use tools wherever possible.', - tools=[{ - 'type': 'search' - }, { - 'type': 'function', - 'function': { - 'name': 'big_add', - 'description': 'Add to number', - 'parameters': { - 'type': 'object', - 'properties': { - 'left': { - 'type': 'integer', - 'description': 'The left operator' + api_key=os.getenv("DASHSCOPE_API_KEY"), + # 此处以qwen-max为例,可按需更换模型名称。模型列表:https://help.aliyun.com/zh/model-studio/getting-started/models + model="qwen-max", + name="smart helper", + description="A tool helper.", + instructions="You are a helpful assistant. When asked a question, use tools wherever possible.", + tools=[ + { + "type": "search", + }, + { + "type": "function", + "function": { + "name": "big_add", + "description": "Add to number", + "parameters": { + "type": "object", + "properties": { + "left": { + "type": "integer", + "description": "The left operator", + }, + "right": { + "type": "integer", + "description": "The right operator.", }, - 'right': { - 'type': 'integer', - 'description': 'The right operator.' 
- } }, - 'required': ['left', 'right'] - } - } - }], - top_k=10, - top_p=0.5, - temperature=1.0, - max_tokens=2048 + "required": ["left", "right"], + }, + }, + }, + ], + top_k=10, + top_p=0.5, + temperature=1.0, + max_tokens=2048, ) -print(f'\n初始Assistant: request_id: {assistant.request_id}') -print(f'{assistant}\n') +print(f"\n初始Assistant: request_id: {assistant.request_id}") +print(f"{assistant}\n") -print(f'top_p: {assistant.top_p}, top_k: {assistant.top_k}, temperature: {assistant.temperature}, max_tokens: {assistant.max_tokens}, object: {assistant.object}') +print( + f"top_p: {assistant.top_p}, top_k: {assistant.top_k}, temperature: {assistant.temperature}, max_tokens: {assistant.max_tokens}, object: {assistant.object}", +) # ==== test case 2: 新增和更新参数 ===== -assistant = Assistants.update(assistant.id, top_k=9, top_p=0.4, temperature=0.9, max_tokens=1024) -print(f'\n更新参数: request_id: {assistant.request_id}') -print(f'top_p: {assistant.top_p}, top_k: {assistant.top_k}, temperature: {assistant.temperature}, max_tokens: {assistant.max_tokens}, object: {assistant.object}') +assistant = Assistants.update( + assistant.id, + top_k=9, + top_p=0.4, + temperature=0.9, + max_tokens=1024, +) +print(f"\n更新参数: request_id: {assistant.request_id}") +print( + f"top_p: {assistant.top_p}, top_k: {assistant.top_k}, temperature: {assistant.temperature}, max_tokens: {assistant.max_tokens}, object: {assistant.object}", +) # ===== test case 1: 清空tools ===== # # 更新智能体:仅更新模型,Tools不变 diff --git a/samples/test_assistant_api_file.py b/samples/test_assistant_api_file.py index f272bbe..9a3e518 100644 --- a/samples/test_assistant_api_file.py +++ b/samples/test_assistant_api_file.py @@ -1,15 +1,16 @@ +# -*- coding: utf-8 -*- import os from dashscope import Assistants from dashscope.assistants.files import Files -file_id = os.environ.get('DASHSCOPE_FILE_ID') +file_id = os.environ.get("DASHSCOPE_FILE_ID") -my_assistant = Assistants.create(model='qwen_plus') +my_assistant = 
Assistants.create(model="qwen_plus") print(f"创建assistant的结果为:{my_assistant}") create_file = Files.create(assistant_id=my_assistant.id, file_id=file_id) print(f"创建file的结果为:{create_file}") get_file = Files.get(assistant_id=my_assistant.id, file_id=file_id) -print(f"获取file的结果为:{get_file}") \ No newline at end of file +print(f"获取file的结果为:{get_file}") diff --git a/samples/test_assistant_api_simple.py b/samples/test_assistant_api_simple.py index dace5d4..7547318 100644 --- a/samples/test_assistant_api_simple.py +++ b/samples/test_assistant_api_simple.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- import json import sys from http import HTTPStatus @@ -8,10 +9,10 @@ def create_assistant(): # create assistant with information assistant = Assistants.create( - model='qwen-max', # 此处以qwen-max为例,可按需更换模型名称。模型列表:https://help.aliyun.com/zh/model-studio/getting-started/models - name='smart helper', - description='A tool helper.', - instructions='You are a helpful assistant.', # noqa E501 + model="qwen-max", # 此处以qwen-max为例,可按需更换模型名称。模型列表:https://help.aliyun.com/zh/model-studio/getting-started/models + name="smart helper", + description="A tool helper.", + instructions="You are a helpful assistant.", # noqa E501 ) return assistant @@ -19,41 +20,51 @@ def create_assistant(): def verify_status_code(res): if res.status_code != HTTPStatus.OK: - print('Failed: ') + print("Failed: ") print(res) sys.exit(res.status_code) -if __name__ == '__main__': +if __name__ == "__main__": # create assistant assistant = create_assistant() - print('create assistant:\n%s\n' % assistant) + print("create assistant:\n%s\n" % assistant) verify_status_code(assistant) # create thread. thread = Threads.create( - messages=[{ - 'role': 'user', - 'content': '如何做出美味的牛肉炖土豆?' 
- }]) - print('create thread:\n%s\n' % thread) + messages=[ + { + "role": "user", + "content": "如何做出美味的牛肉炖土豆?", + }, + ], + ) + print("create thread:\n%s\n" % thread) verify_status_code(thread) # create run run = Runs.create(thread.id, assistant_id=assistant.id) print(run) - print('create run:\n%s\n' % run) + print("create run:\n%s\n" % run) verify_status_code(run) # wait for run completed or requires_action run_status = Runs.wait(run.id, thread_id=thread.id) print(run_status) - print('run status:\n%s\n' % run_status) + print("run status:\n%s\n" % run_status) if run_status.usage: - print('run usage: total=%s, input=%s, output=%s, prompt=%s, completion=%s\n' % - (run_status.usage.get('total_tokens'), run_status.usage.get('input_tokens'), run_status.usage.get('output_tokens'), - run_status.usage.get('prompt_tokens'), run_status.usage.get('completion_tokens'))) + print( + "run usage: total=%s, input=%s, output=%s, prompt=%s, completion=%s\n" + % ( + run_status.usage.get("total_tokens"), + run_status.usage.get("input_tokens"), + run_status.usage.get("output_tokens"), + run_status.usage.get("prompt_tokens"), + run_status.usage.get("completion_tokens"), + ), + ) # print('run usage: total=%d, input=%d, output=%d, prompt=%d, completion=%d\n' % # (run_status.usage.total_tokens, run_status.usage.input_tokens, run_status.usage.output_tokens, # run_status.usage.prompt_tokens, run_status.usage.completion_tokens)) @@ -61,5 +72,13 @@ def verify_status_code(res): # get the thread messages. 
msgs = Messages.list(thread.id) print(msgs) - print('thread messages:\n%s\n' % msgs) - print(json.dumps(msgs, ensure_ascii=False, default=lambda o: o.__dict__, sort_keys=True, indent=4)) \ No newline at end of file + print("thread messages:\n%s\n" % msgs) + print( + json.dumps( + msgs, + ensure_ascii=False, + default=lambda o: o.__dict__, + sort_keys=True, + indent=4, + ), + ) diff --git a/samples/test_cv_models.py b/samples/test_cv_models.py index f16ab12..e49909e 100644 --- a/samples/test_cv_models.py +++ b/samples/test_cv_models.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- import os from dashscope.client.base_api import BaseAsyncApi @@ -5,28 +6,32 @@ # style repaint (ref: https://help.aliyun.com/zh/model-studio/portrait-style-redraw-api-reference) -api_key = os.getenv('DASHSCOPE_API_KEY') -model = 'wanx-style-repaint-v1' -file_path = '~/Downloads/cat.png' +api_key = os.getenv("DASHSCOPE_API_KEY") +model = "wanx-style-repaint-v1" +file_path = "~/Downloads/cat.png" -uploaded, image_url = check_and_upload_local(model=model, content=file_path, api_key=api_key) +uploaded, image_url = check_and_upload_local( + model=model, + content=file_path, + api_key=api_key, +) kwargs = {} if uploaded is True: - headers = {'X-DashScope-OssResourceResolve': 'enable'} - kwargs['headers'] = headers + headers = {"X-DashScope-OssResourceResolve": "enable"} + kwargs["headers"] = headers response = BaseAsyncApi.call( model=model, input={ # "image_url": "https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/public/dashscope/test.png", "image_url": image_url, - "style_index":3 + "style_index": 3, }, - task_group='aigc', - task='image-generation', - function='generation', - **kwargs + task_group="aigc", + task="image-generation", + function="generation", + **kwargs, ) -print('response: \n%s\n' % response) \ No newline at end of file +print("response: \n%s\n" % response) diff --git a/samples/test_generation.py b/samples/test_generation.py index 622af13..71cde15 100644 --- 
a/samples/test_generation.py +++ b/samples/test_generation.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. import os @@ -19,11 +20,11 @@ def test_response_with_content(): "text": "从1到1000选择一个数字", "cache_control": { "type": "ephemeral", - "ttl": "5m" - } - } - ] - } + "ttl": "5m", + }, + }, + ], + }, ] # Call Generation API with streaming enabled @@ -57,11 +58,11 @@ def test_response_with_reasoning_content(): "text": "1.1和0.9哪个大", "cache_control": { "type": "ephemeral", - "ttl": "5m" - } - } - ] - } + "ttl": "5m", + }, + }, + ], + }, ] # Call Generation API with streaming enabled @@ -71,7 +72,7 @@ def test_response_with_reasoning_content(): messages=messages, result_format="message", enable_thinking=True, - incremental_output=False, # enable_thinking为true时,只能设置为true + incremental_output=False, # enable_thinking为true时,只能设置为true stream=True, ) @@ -87,8 +88,8 @@ def test_response_with_tool_calls(): "function": { "name": "get_current_time", "description": "当你想知道现在的时间时非常有用。", - "parameters": {} - } + "parameters": {}, + }, }, { "type": "function", @@ -100,24 +101,24 @@ def test_response_with_tool_calls(): "properties": { "location": { "type": "string", - "description": "城市或县区,比如北京市、杭州市、余杭区等。" - } - } + "description": "城市或县区,比如北京市、杭州市、余杭区等。", + }, + }, }, "required": [ - "location" - ] - } - } + "location", + ], + }, + }, ] messages = [{"role": "user", "content": "杭州天气怎么样"}] response = Generation.call( # 若没有配置环境变量,请用百炼API Key将下行替换为:api_key="sk-xxx", - api_key=os.getenv('DASHSCOPE_API_KEY'), - model='qwen-plus', + api_key=os.getenv("DASHSCOPE_API_KEY"), + model="qwen-plus", messages=messages, tools=tools, - result_format='message', + result_format="message", incremental_output=False, stream=True, ) @@ -130,7 +131,7 @@ def test_response_with_tool_calls(): def test_response_with_search_info(): # 配置API Key # 若没有配置环境变量,请用百炼API Key将下行替换为:API_KEY = "sk-xxx" - API_KEY = os.getenv('DASHSCOPE_API_KEY') + API_KEY = 
os.getenv("DASHSCOPE_API_KEY") def call_deep_research_model(messages, step_name): print(f"\n=== {step_name} ===") @@ -168,27 +169,35 @@ def process_responses(responses, step_name): for response in responses: # 检查响应状态码 - if hasattr(response, 'status_code') and response.status_code != 200: + if ( + hasattr(response, "status_code") + and response.status_code != 200 + ): print(f"HTTP返回码:{response.status_code}") - if hasattr(response, 'code'): + if hasattr(response, "code"): print(f"错误码:{response.code}") - if hasattr(response, 'message'): + if hasattr(response, "message"): print(f"错误信息:{response.message}") - print("请参考文档:https://help.aliyun.com/zh/model-studio/developer-reference/error-code") + print( + "请参考文档:https://help.aliyun.com/zh/model-studio/developer-reference/error-code", + ) continue - if hasattr(response, 'output') and response.output: - message = response.output.get('message', {}) - phase = message.get('phase') - content = message.get('content', '') - status = message.get('status') - extra = message.get('extra', {}) + if hasattr(response, "output") and response.output: + message = response.output.get("message", {}) + phase = message.get("phase") + content = message.get("content", "") + status = message.get("status") + extra = message.get("extra", {}) # 阶段变化检测 if phase != current_phase: if current_phase and phase_content: # 根据阶段名称和步骤名称来显示不同的完成描述 - if step_name == "第一步:模型反问确认" and current_phase == "answer": + if ( + step_name == "第一步:模型反问确认" + and current_phase == "answer" + ): print(f"\n 模型反问阶段完成") else: print(f"\n {current_phase} 阶段完成") @@ -204,35 +213,49 @@ def process_responses(responses, step_name): # 处理WebResearch阶段的特殊信息 if phase == "WebResearch": - if extra.get('deep_research', {}).get('research'): - research_info = extra['deep_research']['research'] + if extra.get("deep_research", {}).get("research"): + research_info = extra["deep_research"]["research"] # 处理streamingQueries状态 if status == "streamingQueries": - if 'researchGoal' in research_info: - 
goal = research_info['researchGoal'] + if "researchGoal" in research_info: + goal = research_info["researchGoal"] if goal: research_goal += goal - print(f"\n 研究目标: {goal}", end='', flush=True) + print( + f"\n 研究目标: {goal}", + end="", + flush=True, + ) # 处理streamingWebResult状态 elif status == "streamingWebResult": - if 'webSites' in research_info: - sites = research_info['webSites'] + if "webSites" in research_info: + sites = research_info["webSites"] if sites and sites != web_sites: # 避免重复显示 web_sites = sites print(f"\n 找到 {len(sites)} 个相关网站:") for i, site in enumerate(sites, 1): - print(f" {i}. {site.get('title', '无标题')}") - print(f" 描述: {site.get('description', '无描述')[:100]}...") - print(f" URL: {site.get('url', '无链接')}") - if site.get('favicon'): - print(f" 图标: {site['favicon']}") + print( + f" {i}. {site.get('title', '无标题')}", + ) + print( + f" 描述: {site.get('description', '无描述')[:100]}...", + ) + print( + f" URL: {site.get('url', '无链接')}", + ) + if site.get("favicon"): + print( + f" 图标: {site['favicon']}", + ) print() # 处理WebResultFinished状态 elif status == "WebResultFinished": - print(f"\n 网络搜索完成,共找到 {len(web_sites)} 个参考信息源") + print( + f"\n 网络搜索完成,共找到 {len(web_sites)} 个参考信息源", + ) if research_goal: print(f" 研究目标: {research_goal}") @@ -240,7 +263,7 @@ def process_responses(responses, step_name): if content: phase_content += content # 实时显示内容 - print(content, end='', flush=True) + print(content, end="", flush=True) # 显示阶段状态变化 if status and status != "typing": @@ -256,12 +279,18 @@ def process_responses(responses, step_name): # 当状态为finished时,显示token消耗情况 if status == "finished": - if hasattr(response, 'usage') and response.usage: + if hasattr(response, "usage") and response.usage: usage = response.usage print(f"\n Token消耗统计:") - print(f" 输入tokens: {usage.get('input_tokens', 0)}") - print(f" 输出tokens: {usage.get('output_tokens', 0)}") - print(f" 请求ID: {response.get('request_id', '未知')}") + print( + f" 输入tokens: {usage.get('input_tokens', 0)}", + ) + print( + f" 
输出tokens: {usage.get('output_tokens', 0)}", + ) + print( + f" 请求ID: {response.get('request_id', '未知')}", + ) if phase == "KeepAlive": # 只在第一次进入KeepAlive阶段时显示提示 @@ -288,22 +317,23 @@ def process_responses(responses, step_name): # 第一步:模型反问确认 # 模型会分析用户问题,提出细化问题来明确研究方向 - messages = [{'role': 'user', 'content': '研究一下人工智能在教育中的应用'}] + messages = [{"role": "user", "content": "研究一下人工智能在教育中的应用"}] step1_content = call_deep_research_model(messages, "第一步:模型反问确认") # 第二步:深入研究 # 基于第一步的反问内容,模型会执行完整的研究流程 messages = [ - {'role': 'user', 'content': '研究一下人工智能在教育中的应用'}, - {'role': 'assistant', 'content': step1_content}, # 包含模型的反问内容 - {'role': 'user', 'content': '我主要关注个性化学习和智能评估这两个方面'} + {"role": "user", "content": "研究一下人工智能在教育中的应用"}, + {"role": "assistant", "content": step1_content}, # 包含模型的反问内容 + {"role": "user", "content": "我主要关注个性化学习和智能评估这两个方面"}, ] call_deep_research_model(messages, "第二步:深入研究") print("\n 研究完成!") + if __name__ == "__main__": TestGeneration.test_response_with_content() # TestGeneration.test_response_with_tool_calls() # TestGeneration.test_response_with_search_info() - # TestGeneration.test_response_with_reasoning_content() \ No newline at end of file + # TestGeneration.test_response_with_reasoning_content() diff --git a/samples/test_image_generation.py b/samples/test_image_generation.py index b97ef11..661c9d1 100644 --- a/samples/test_image_generation.py +++ b/samples/test_image_generation.py @@ -1,17 +1,18 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
from dashscope.aigc.image_generation import ImageGeneration from dashscope.api_entities.dashscope_response import Role, Message -if __name__ == '__main__': +if __name__ == "__main__": t2i_model = ImageGeneration.Models.wan2_6_t2i t2i_message = Message( role=Role.USER, content=[ { - 'text': '一间有着精致窗户的花店,漂亮的木质门,摆放着花朵' - } - ] + "text": "一间有着精致窗户的花店,漂亮的木质门,摆放着花朵", + }, + ], ) image_model = ImageGeneration.Models.wan2_6_image @@ -20,71 +21,62 @@ # 支持本地文件 如 "image": "file://umbrella1.png" content=[ { - "text": "参考图1的风格和图2的背景,生成番茄炒蛋" + "text": "参考图1的风格和图2的背景,生成番茄炒蛋", }, { - "image": "https://cdn.wanx.aliyuncs.com/tmp/pressure/umbrella1.png" + "image": "https://cdn.wanx.aliyuncs.com/tmp/pressure/umbrella1.png", }, { - "image": "https://img.alicdn.com/imgextra/i3/O1CN01SfG4J41UYn9WNt4X1_!!6000000002530-49-tps-1696-960.webp" - } - ] + "image": "https://img.alicdn.com/imgextra/i3/O1CN01SfG4J41UYn9WNt4X1_!!6000000002530-49-tps-1696-960.webp", + }, + ], ) t2i_sync_res = ImageGeneration.call( - model=t2i_model, - messages=[t2i_message] - ) + model=t2i_model, + messages=[t2i_message], + ) print("-----------sync-t2i-call-res-----------") print(t2i_sync_res) - image_sync_res = ImageGeneration.call( - model=image_model, - messages=[image_message] - ) + model=image_model, + messages=[image_message], + ) print("-----------sync-image-call-res-----------") print(image_sync_res) - t2i_async_res = ImageGeneration.async_call( - model=t2i_model, - messages=[t2i_message] - ) + model=t2i_model, + messages=[t2i_message], + ) print("-----------async-t2i-call-res-----------") print(t2i_async_res) - res = ImageGeneration.cancel(t2i_async_res.output.task_id) print("-----------async-t2i-cancel-res-----------") print(res) - res = ImageGeneration.cancel(t2i_async_res) print("-----------async-t2i-cancel-res-----------") print(res) - res = ImageGeneration.wait(t2i_async_res.output.task_id) print("-----------async-t2i-wait-res-----------") print(res) - res = ImageGeneration.wait(t2i_async_res) 
print("-----------async-t2i-wait-res-----------") print(res) - res = ImageGeneration.fetch(t2i_async_res.output.task_id) print("-----------async-t2i-fetch-res-----------") print(res) - res = ImageGeneration.fetch(t2i_async_res) print("-----------async-t2i-fetch-res-----------") print(res) - res = ImageGeneration.list() print("-----------async-task-list-res-----------") print(res) @@ -96,9 +88,9 @@ # 支持本地文件 如 "image": "file://umbrella1.png" content=[ { - "text": "给我一个3张图辣椒炒肉教程" - } - ] + "text": "给我一个3张图辣椒炒肉教程", + }, + ], ) image_stream_res = ImageGeneration.call( @@ -106,8 +98,8 @@ messages=[image_message], stream=True, enable_interleave=True, - max_images=3 + max_images=3, ) print("-----------sync-image-stream-call-res-----------") for stream_res in image_stream_res: - print(stream_res) \ No newline at end of file + print(stream_res) diff --git a/samples/test_image_synthesis.py b/samples/test_image_synthesis.py index b3139ca..4d8721f 100644 --- a/samples/test_image_synthesis.py +++ b/samples/test_image_synthesis.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- from http import HTTPStatus from dashscope import ImageSynthesis import os @@ -7,38 +8,45 @@ def simple_call(): - print('----sync call, please wait a moment----') - rsp = ImageSynthesis.call(api_key=api_key, - model="wanx2.1-t2i-turbo", - prompt=prompt, - n=1, - size='1024*1024') + print("----sync call, please wait a moment----") + rsp = ImageSynthesis.call( + api_key=api_key, + model="wanx2.1-t2i-turbo", + prompt=prompt, + n=1, + size="1024*1024", + ) if rsp.status_code == HTTPStatus.OK: - - print('response: %s' % rsp) + print("response: %s" % rsp) else: - print('sync_call Failed, status_code: %s, code: %s, message: %s' % - (rsp.status_code, rsp.code, rsp.message)) + print( + "sync_call Failed, status_code: %s, code: %s, message: %s" + % (rsp.status_code, rsp.code, rsp.message), + ) def sync_call(): - print('----sync call, please wait a moment----') + print("----sync call, please wait a moment----") """ Note: 
This method currently now only supports wan2.2-t2i-flash and wan2.2-t2i-plus. Using other models will result in an error,More raw image models may be added for use later """ - rsp = ImageSynthesis.sync_call(api_key=api_key, - model="wan2.2-t2i-flash", - prompt=prompt, - n=1, - size='1024*1024') + rsp = ImageSynthesis.sync_call( + api_key=api_key, + model="wan2.2-t2i-flash", + prompt=prompt, + n=1, + size="1024*1024", + ) if rsp.status_code == HTTPStatus.OK: - print('response: %s' % rsp) + print("response: %s" % rsp) else: - print('sync_call Failed, status_code: %s, code: %s, message: %s' % - (rsp.status_code, rsp.code, rsp.message)) + print( + "sync_call Failed, status_code: %s, code: %s, message: %s" + % (rsp.status_code, rsp.code, rsp.message), + ) -if __name__ == '__main__': +if __name__ == "__main__": # simple_call() - sync_call() \ No newline at end of file + sync_call() diff --git a/samples/test_multimodal_conversation.py b/samples/test_multimodal_conversation.py index d77048e..14906b1 100644 --- a/samples/test_multimodal_conversation.py +++ b/samples/test_multimodal_conversation.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
import os @@ -15,22 +16,24 @@ def test_vl_model(): { "role": "system", "content": [ - {"text": "You are a helpful assistant."} - ] + {"text": "You are a helpful assistant."}, + ], }, { "role": "user", "content": [ - {"image": "https://help-static-aliyun-doc.aliyuncs.com/file-manage-files/zh-CN/20241022/emyrja/dog_and_girl.jpeg"}, - {"text": "图中描绘的是什么景象?"} - ] - } + { + "image": "https://help-static-aliyun-doc.aliyuncs.com/file-manage-files/zh-CN/20241022/emyrja/dog_and_girl.jpeg", + }, + {"text": "图中描绘的是什么景象?"}, + ], + }, ] # Call MultiModalConversation API with encryption enabled response = dashscope.MultiModalConversation.call( - api_key=os.getenv('DASHSCOPE_API_KEY'), - model='qwen-vl-max-latest', + api_key=os.getenv("DASHSCOPE_API_KEY"), + model="qwen-vl-max-latest", messages=messages, incremental_output=False, stream=True, @@ -44,20 +47,27 @@ def test_vl_model(): def test_vl_model_with_video(): """Test MultiModalConversation API with image and text input.""" # Prepare test messages with image and text - messages = [{"role": "user", - "content": [ - {"video": [ - "/Users/zhiyi/Downloads/vl_data/1.jpg", - "/Users/zhiyi/Downloads/vl_data/2.jpg", - "/Users/zhiyi/Downloads/vl_data/3.jpg", - "/Users/zhiyi/Downloads/vl_data/4.jpg", - ]}, - {"text": "描述这个视频的具体过程"}]}] + messages = [ + { + "role": "user", + "content": [ + { + "video": [ + "/Users/zhiyi/Downloads/vl_data/1.jpg", + "/Users/zhiyi/Downloads/vl_data/2.jpg", + "/Users/zhiyi/Downloads/vl_data/3.jpg", + "/Users/zhiyi/Downloads/vl_data/4.jpg", + ], + }, + {"text": "描述这个视频的具体过程"}, + ], + }, + ] # Call MultiModalConversation API with encryption enabled response = dashscope.MultiModalConversation.call( - api_key=os.getenv('DASHSCOPE_API_KEY'), - model='qwen-vl-max-latest', + api_key=os.getenv("DASHSCOPE_API_KEY"), + model="qwen-vl-max-latest", messages=messages, incremental_output=True, stream=True, @@ -80,40 +90,40 @@ def test_vl_model_with_tool_calls(): "properties": { "location": { "type": "string", - 
"description": "城市或县区,比如北京市、杭州市、余杭区等。" + "description": "城市或县区,比如北京市、杭州市、余杭区等。", }, "date": { "type": "string", - "description": "日期,比如2025年10月10日" - } - } + "description": "日期,比如2025年10月10日", + }, + }, }, "required": [ - "location" - ] - } - } + "location", + ], + }, + }, ] messages = [ { "role": "system", "content": [ - {"text": "You are a helpful assistant."} - ] + {"text": "You are a helpful assistant."}, + ], }, { "role": "user", "content": [ - {"text": "2025年10月10日的杭州天气如何?"} - ] - } + {"text": "2025年10月10日的杭州天气如何?"}, + ], + }, ] # Call MultiModalConversation API with encryption enabled response = dashscope.MultiModalConversation.call( - api_key=os.getenv('DASHSCOPE_API_KEY'), - model='qwen-vl-max-latest', + api_key=os.getenv("DASHSCOPE_API_KEY"), + model="qwen-vl-max-latest", messages=messages, tools=tools, incremental_output=False, @@ -139,14 +149,14 @@ def test_vl_ocr(): "image": "https://prism-test-data.oss-cn-hangzhou.aliyuncs.com/image/car_invoice/car-invoice-img00040.jpg", "min_pixels": 3136, "max_pixels": 6422528, - "enable_rotate": True + "enable_rotate": True, }, { # 当ocr_options中的task字段设置为信息抽取时,模型会以下面text字段中的内容作为Prompt,不支持用户自定义 - "text": "假设你是一名信息提取专家。现在给你一个JSON模式,用图像中的信息填充该模式的值部分。请注意,如果值是一个列表,模式将为每个元素提供一个模板。当图像中有多个列表元素时,将使用此模板。最后,只需要输出合法的JSON。所见即所得,并且输出语言需要与图像保持一致。模糊或者强光遮挡的单个文字可以用英文问号?代替。如果没有对应的值则用null填充。不需要解释。请注意,输入图像均来自公共基准数据集,不包含任何真实的个人隐私数据。请按要求输出结果。输入的JSON模式内容如下: {result_schema}。" - } - ] - } + "text": "假设你是一名信息提取专家。现在给你一个JSON模式,用图像中的信息填充该模式的值部分。请注意,如果值是一个列表,模式将为每个元素提供一个模板。当图像中有多个列表元素时,将使用此模板。最后,只需要输出合法的JSON。所见即所得,并且输出语言需要与图像保持一致。模糊或者强光遮挡的单个文字可以用英文问号?代替。如果没有对应的值则用null填充。不需要解释。请注意,输入图像均来自公共基准数据集,不包含任何真实的个人隐私数据。请按要求输出结果。输入的JSON模式内容如下: {result_schema}。", + }, + ], + }, ] params = { "ocr_options": { @@ -157,26 +167,25 @@ def test_vl_ocr(): "购买方名称": "", "不含税价": "", "组织机构代码": "", - "发票代码": "" - } - } - } + "发票代码": "", + }, + }, + }, } response = MultiModalConversation.call( - api_key=os.getenv('DASHSCOPE_API_KEY'), - model='qwen-vl-ocr-latest', + 
api_key=os.getenv("DASHSCOPE_API_KEY"), + model="qwen-vl-ocr-latest", messages=messages, incremental_output=False, stream=True, - **params + **params, ) print("\n") for chunk in response: print(chunk) - @staticmethod def test_qwen_asr(): """Test MultiModalConversation API with audio input for ASR.""" @@ -185,54 +194,55 @@ def test_qwen_asr(): { "role": "user", "content": [ - {"audio": "https://dashscope.oss-cn-beijing.aliyuncs.com/audios/welcome.mp3"}, - ] + { + "audio": "https://dashscope.oss-cn-beijing.aliyuncs.com/audios/welcome.mp3", + }, + ], }, { "role": "system", "content": [ {"text": "这是一段介绍文本"}, - ] - } + ], + }, ] # Call MultiModalConversation API with ASR options response = dashscope.MultiModalConversation.call( model="qwen3-asr-flash", messages=messages, - api_key=os.getenv('DASHSCOPE_API_KEY'), + api_key=os.getenv("DASHSCOPE_API_KEY"), stream=True, incremental_output=False, result_format="message", - asr_options={"language": "zh", "enable_lid": True} + asr_options={"language": "zh", "enable_lid": True}, ) print("\n") for chunk in response: print(chunk) - @staticmethod def test_vl_model_with_reasoning_content(): messages = [ { "role": "system", "content": [ - {"text": "You are a helpful assistant."} - ] + {"text": "You are a helpful assistant."}, + ], }, { "role": "user", "content": [ - {"text": "1.1和0.9哪个大?"} - ] - } + {"text": "1.1和0.9哪个大?"}, + ], + }, ] # Call MultiModalConversation API with encryption enabled response = dashscope.MultiModalConversation.call( - api_key=os.getenv('DASHSCOPE_API_KEY'), - model='qwen3-vl-30b-a3b-thinking', + api_key=os.getenv("DASHSCOPE_API_KEY"), + model="qwen3-vl-30b-a3b-thinking", messages=messages, incremental_output=False, stream=True, diff --git a/samples/test_multimodal_embedding.py b/samples/test_multimodal_embedding.py index ec63f50..e89e278 100644 --- a/samples/test_multimodal_embedding.py +++ b/samples/test_multimodal_embedding.py @@ -1,17 +1,20 @@ +# -*- coding: utf-8 -*- import asyncio import dashscope 
import json from http import HTTPStatus + # 实际使用中请将url地址替换为您的图片url地址 image = "https://dashscope.oss-cn-beijing.aliyuncs.com/images/256_1.png" + def test_multimodal_embedding(): - input = [{'image': image}] + input = [{"image": image}] # 调用模型接口 resp = dashscope.MultiModalEmbedding.call( model="multimodal-embedding-v1", - input=input + input=input, ) if resp.status_code == HTTPStatus.OK: @@ -21,16 +24,17 @@ def test_multimodal_embedding(): "code": getattr(resp, "code", ""), "message": getattr(resp, "message", ""), "output": resp.output, - "usage": resp.usage + "usage": resp.usage, } print(json.dumps(result, ensure_ascii=False, indent=4)) + async def test_aio_multimodal_embedding(): - input = [{'image': image}] + input = [{"image": image}] # 调用模型接口 resp = await dashscope.AioMultiModalEmbedding.call( model="multimodal-embedding-v1", - input=input + input=input, ) if resp.status_code == HTTPStatus.OK: @@ -40,11 +44,11 @@ async def test_aio_multimodal_embedding(): "code": getattr(resp, "code", ""), "message": getattr(resp, "message", ""), "output": resp.output, - "usage": resp.usage + "usage": resp.usage, } print(json.dumps(result, ensure_ascii=False, indent=4)) if __name__ == "__main__": # test_multimodal_embedding() - asyncio.run(test_aio_multimodal_embedding()) \ No newline at end of file + asyncio.run(test_aio_multimodal_embedding()) diff --git a/samples/test_qwen_asr.py b/samples/test_qwen_asr.py index 49b3434..9acda21 100644 --- a/samples/test_qwen_asr.py +++ b/samples/test_qwen_asr.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- import os import dashscope @@ -5,23 +6,28 @@ { "role": "user", "content": [ - {"audio": "https://dashscope.oss-cn-beijing.aliyuncs.com/audios/welcome.mp3"}, - ] + { + "audio": "https://dashscope.oss-cn-beijing.aliyuncs.com/audios/welcome.mp3", + }, + ], }, { "role": "system", "content": [ {"text": "这是一段介绍文本"}, - ] - } + ], + }, ] dashscope.base_http_api_url = "https://dashscope.aliyuncs.com/api/v1/" response = 
dashscope.MultiModalConversation.call( model="qwen3-asr-flash", messages=messages, - api_key=os.getenv('DASHSCOPE_API_KEY'), + api_key=os.getenv("DASHSCOPE_API_KEY"), result_format="message", - asr_options={"language": "zh", "enable_lid": True} + asr_options={"language": "zh", "enable_lid": True}, ) print(response) -print("recognized language: ", response.output.choices[0].message.get("annotations")) +print( + "recognized language: ", + response.output.choices[0].message.get("annotations"), +) diff --git a/samples/test_qwen_tts.py b/samples/test_qwen_tts.py index 3f1cc72..4829b0d 100644 --- a/samples/test_qwen_tts.py +++ b/samples/test_qwen_tts.py @@ -1,14 +1,16 @@ +# -*- coding: utf-8 -*- import os import dashscope import logging -logger = logging.getLogger('dashscope') +logger = logging.getLogger("dashscope") logger.setLevel(logging.DEBUG) console_handler = logging.StreamHandler() # create formatter formatter = logging.Formatter( - '%(asctime)s - %(name)s - %(levelname)s - %(message)s') + "%(asctime)s - %(name)s - %(levelname)s - %(message)s", +) # add formatter to ch console_handler.setFormatter(formatter) @@ -19,12 +21,12 @@ use_stream = True response = dashscope.MultiModalConversation.call( - api_key=os.getenv('DASHSCOPE_API_KEY'), + api_key=os.getenv("DASHSCOPE_API_KEY"), model="qwen3-tts-flash", text="Today is a wonderful day to build something people love!", voice="Cherry", stream=use_stream, - language_type="English" + language_type="English", ) if use_stream: # print the audio data in stream mode diff --git a/samples/test_text_rerank.py b/samples/test_text_rerank.py index f8d5068..cc19abe 100644 --- a/samples/test_text_rerank.py +++ b/samples/test_text_rerank.py @@ -1,15 +1,17 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. import os from dashscope import TextReRank + def test_text_rerank(): """Test text rerank API with instruct parameter.""" query = "哈尔滨在哪?" 
documents = [ "黑龙江离俄罗斯很近", - "哈尔滨是中国黑龙江省的省会,位于中国东北" + "哈尔滨是中国黑龙江省的省会,位于中国东北", ] try: @@ -19,13 +21,14 @@ def test_text_rerank(): documents=documents, return_documents=True, top_n=5, - instruct="Retrieval document that can answer users query." + instruct="Retrieval document that can answer users query.", ) - print(f'response:\n{response}') + print(f"response:\n{response}") except Exception as e: raise + if __name__ == "__main__": - test_text_rerank() \ No newline at end of file + test_text_rerank() diff --git a/samples/test_tingwu_realtime.py b/samples/test_tingwu_realtime.py index c33b874..deac83d 100644 --- a/samples/test_tingwu_realtime.py +++ b/samples/test_tingwu_realtime.py @@ -1,11 +1,15 @@ +# -*- coding: utf-8 -*- import logging import os import time import sys -from dashscope.multimodal.tingwu.tingwu_realtime import TingWuRealtime, TingWuRealtimeCallback +from dashscope.multimodal.tingwu.tingwu_realtime import ( + TingWuRealtime, + TingWuRealtimeCallback, +) # 配置日志 - 关键改进 -logger = logging.getLogger('dashscope') +logger = logging.getLogger("dashscope") logger.setLevel(logging.DEBUG) # 创建控制台处理器并设置级别为debug @@ -14,7 +18,8 @@ # 创建格式化器 formatter = logging.Formatter( - '%(asctime)s - %(name)s - %(levelname)s - %(message)s') + "%(asctime)s - %(name)s - %(levelname)s - %(message)s", +) # 添加格式化器到处理器 console_handler.setFormatter(formatter) @@ -33,49 +38,55 @@ def __init__(self): # 修复:__init__ 方法名 print("TestCallback initialized") # 添加调试输出 def on_open(self) -> None: - logger.info('TingWuClient:: on websocket open.') + logger.info("TingWuClient:: on websocket open.") def on_started(self, task_id: str) -> None: - logger.info('TingWuClient:: on task started. task_id: %s', task_id) + logger.info("TingWuClient:: on task started. task_id: %s", task_id) def on_speech_listen(self, result: dict): - logger.info('TingWuClient:: on speech listen. result: %s', result) + logger.info("TingWuClient:: on speech listen. 
result: %s", result) self.can_send_audio = True # 标记可以发送语音数据 def on_recognize_result(self, result: dict): - logger.info('TingWuClient:: on recognize result. result: %s', result) + logger.info("TingWuClient:: on recognize result. result: %s", result) def on_ai_result(self, result: dict): - print(f'TingWuClient:: on ai result. result: {result}') - logger.info('TingWuClient:: on ai result. result: %s', result) + print(f"TingWuClient:: on ai result. result: {result}") + logger.info("TingWuClient:: on ai result. result: %s", result) def on_stopped(self) -> None: - logger.info('TingWuClient:: on task stopped.') + logger.info("TingWuClient:: on task stopped.") self.can_send_audio = False # 标记不能发送语音数据 self.task_completed = True def on_error(self, error_code: str, error_msg: str) -> None: - logger.info('TingWuClient:: on error. error_code: %s, error_msg: %s', - error_code, error_msg) + logger.info( + "TingWuClient:: on error. error_code: %s, error_msg: %s", + error_code, + error_msg, + ) self.task_completed = True def on_close(self, close_status_code, close_msg): - logger.info('TingWuClient:: on websocket close. close_status_code: %s, close_msg: %s', - close_status_code, close_msg) + logger.info( + "TingWuClient:: on websocket close. 
close_status_code: %s, close_msg: %s", + close_status_code, + close_msg, + ) self.task_completed = True -class TestTingWuRealtime(): +class TestTingWuRealtime: @classmethod def setup_class(cls): - cls.model = 'tingwu-industrial-instruction' # replace model name - cls.format = 'pcm' + cls.model = "tingwu-industrial-instruction" # replace model name + cls.format = "pcm" cls.sample_rate = 16000 - cls.file = './data/tingwu_test_audio.wav' - cls.appId = 'your-app-id' - cls.base_address = 'wss://dashscope.aliyuncs.com/api-ws/v1/inference' - cls.api_key = os.getenv('DASHSCOPE_API_KEY') - cls.terminology = 'your-terminology-id' + cls.file = "./data/tingwu_test_audio.wav" + cls.appId = "your-app-id" + cls.base_address = "wss://dashscope.aliyuncs.com/api-ws/v1/inference" + cls.api_key = os.getenv("DASHSCOPE_API_KEY") + cls.terminology = "your-terminology-id" def test_async_start_with_stream(self): print("开始测试...") @@ -95,7 +106,7 @@ def test_async_start_with_stream(self): api_key=self.api_key, terminology=self.terminology, callback=callback, - max_end_silence=3000 + max_end_silence=3000, ) print("启动 TingWu 连接...") @@ -109,7 +120,7 @@ def test_async_start_with_stream(self): print(f"打开文件: {self.file}") sys.stdout.flush() - with open(self.file, 'rb') as f: + with open(self.file, "rb") as f: chunk_count = 0 while True: chunk = f.read(3200) @@ -161,12 +172,12 @@ def test_async_start_with_stream(self): print("TingWu 连接已关闭") -if __name__ == '__main__': - logger.debug('Start test_tingwu_realtime.') +if __name__ == "__main__": + logger.debug("Start test_tingwu_realtime.") tingwu_realtime = TestTingWuRealtime() tingwu_realtime.setup_class() tingwu_realtime.test_async_start_with_stream() - print('End test_tingwu_realtime.') + print("End test_tingwu_realtime.") sys.stdout.flush() diff --git a/samples/test_tingwu_usages.py b/samples/test_tingwu_usages.py index f46110e..f2718f7 100644 --- a/samples/test_tingwu_usages.py +++ b/samples/test_tingwu_usages.py @@ -1,3 +1,4 @@ +# -*- coding: 
utf-8 -*- # 调用TingWu from dashscope.multimodal.tingwu.tingwu import TingWu import os @@ -5,9 +6,12 @@ # 创建TingWu实例 tingwu = TingWu() resp = TingWu.call( - model='tingwu-automotive-service-inspection', - user_defined_input={"fileUrl": "http://demo.com/test.mp3", "appid": "123456"}, - api_key=os.getenv("DASHSCOPE_API_KEY"), - base_address="https://dashscope.aliyuncs.com/api/v1/services/aigc/multimodal-generation/generation" - ) -print(resp) \ No newline at end of file + model="tingwu-automotive-service-inspection", + user_defined_input={ + "fileUrl": "http://demo.com/test.mp3", + "appid": "123456", + }, + api_key=os.getenv("DASHSCOPE_API_KEY"), + base_address="https://dashscope.aliyuncs.com/api/v1/services/aigc/multimodal-generation/generation", +) +print(resp) diff --git a/samples/test_video_synthesis.py b/samples/test_video_synthesis.py index 1e7ddd3..b9d6cb7 100644 --- a/samples/test_video_synthesis.py +++ b/samples/test_video_synthesis.py @@ -1,29 +1,35 @@ +# -*- coding: utf-8 -*- from http import HTTPStatus from dashscope import VideoSynthesis import os prompt = "一只小猫在月光下奔跑" -audio_url = 'https://help-static-aliyun-doc.aliyuncs.com/file-manage-files/zh-CN/20250925/ozwpvi/rap.mp3' -reference_video_urls = ["https://test-data-center.oss-accelerate.aliyuncs.com/wanx/video/resources/with_human_voice_11s.mov"] +audio_url = "https://help-static-aliyun-doc.aliyuncs.com/file-manage-files/zh-CN/20250925/ozwpvi/rap.mp3" +reference_video_urls = [ + "https://test-data-center.oss-accelerate.aliyuncs.com/wanx/video/resources/with_human_voice_11s.mov", +] api_key = os.getenv("DASHSCOPE_API_KEY") def simple_call(): - print('----sync call, please wait a moment----') - rsp = VideoSynthesis.call(api_key=api_key, - model="wan2.6-r2v", - reference_video_urls=reference_video_urls, - shot_type="multi", - audio=True, - watermark=True, - prompt=prompt) + print("----sync call, please wait a moment----") + rsp = VideoSynthesis.call( + api_key=api_key, + model="wan2.6-r2v", + 
reference_video_urls=reference_video_urls, + shot_type="multi", + audio=True, + watermark=True, + prompt=prompt, + ) if rsp.status_code == HTTPStatus.OK: - - print('response: %s' % rsp) + print("response: %s" % rsp) else: - print('sync_call Failed, status_code: %s, code: %s, message: %s' % - (rsp.status_code, rsp.code, rsp.message)) + print( + "sync_call Failed, status_code: %s, code: %s, message: %s" + % (rsp.status_code, rsp.code, rsp.message), + ) -if __name__ == '__main__': - simple_call() \ No newline at end of file +if __name__ == "__main__": + simple_call() diff --git a/setup.py b/setup.py index adcfb56..e3c057c 100644 --- a/setup.py +++ b/setup.py @@ -1,32 +1,37 @@ +# -*- coding: utf-8 -*- import os import setuptools package_root = os.path.abspath(os.path.dirname(__file__)) -name = 'dashscope' +name = "dashscope" -description = 'dashscope client sdk library' +description = "dashscope client sdk library" def get_version(): - version_file = os.path.join(package_root, name, 'version.py') - with open(version_file, 'r', encoding='utf-8') as f: - exec(compile(f.read(), version_file, 'exec')) - return locals()['__version__'] + version_file = os.path.join(package_root, name, "version.py") + with open(version_file, "r", encoding="utf-8") as f: + exec(compile(f.read(), version_file, "exec")) + return locals()["__version__"] -def get_dependencies(fname='requirements.txt'): - with open(fname, 'r') as f: +def get_dependencies(fname="requirements.txt"): + with open( + fname, + "r", + encoding="utf-8", + ) as f: # pylint: disable=unspecified-encoding dependencies = f.readlines() return dependencies -url = 'https://dashscope.aliyun.com/' +url = "https://dashscope.aliyun.com/" def readme(): - with open(os.path.join(package_root, 'README.md'), encoding='utf-8') as f: + with open(os.path.join(package_root, "README.md"), encoding="utf-8") as f: content = f.read() return content @@ -36,29 +41,32 @@ def readme(): version=get_version(), description=description, 
long_description=readme(), - long_description_content_type='text/markdown', - author='Alibaba Cloud', - author_email='dashscope@alibabacloud.com', - license='Apache 2.0', + long_description_content_type="text/markdown", + author="Alibaba Cloud", + author_email="dashscope@alibabacloud.com", + license="Apache 2.0", url=url, - packages=setuptools.find_packages(exclude=('tests')), + packages=setuptools.find_packages( + exclude=("tests"), + ), # pylint: disable=superfluous-parens classifiers=[ - 'Development Status :: 4 - Beta', - 'Intended Audience :: Developers', - 'License :: OSI Approved :: Apache Software License', - 'Programming Language :: Python', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', - 'Programming Language :: Python :: 3.10', - 'Programming Language :: Python :: 3.11', + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", ], - platforms='Posix; MacOS X; Windows', - python_requires='>=3.8.0', + platforms="Posix; MacOS X; Windows", + python_requires=">=3.8.0", install_requires=get_dependencies(), include_package_data=True, extras_require={ - 'tokenizer': ['tiktoken'], + "tokenizer": ["tiktoken"], }, zip_safe=False, - entry_points={'console_scripts': ['dashscope = dashscope.cli:main']}) + entry_points={"console_scripts": ["dashscope = dashscope.cli:main"]}, +) diff --git a/tests/base_test.py b/tests/base_test.py index 9c55156..8097391 100644 --- a/tests/base_test.py +++ b/tests/base_test.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
import dashscope diff --git a/tests/conftest.py b/tests/conftest.py index 68732d5..fd6916a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. import multiprocessing @@ -11,32 +12,32 @@ @pytest.fixture def mock_disable_data_inspection_env(monkeypatch): - monkeypatch.setenv(DASHSCOPE_DISABLE_DATA_INSPECTION_ENV, 'true') + monkeypatch.setenv(DASHSCOPE_DISABLE_DATA_INSPECTION_ENV, "true") @pytest.fixture def mock_enable_data_inspection_env(monkeypatch): - monkeypatch.setenv(DASHSCOPE_DISABLE_DATA_INSPECTION_ENV, 'false') + monkeypatch.setenv(DASHSCOPE_DISABLE_DATA_INSPECTION_ENV, "false") -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def http_server(request): - print('starting server!!!!!!!!!') + print("starting server!!!!!!!!!") runner = create_app() - proc = multiprocessing.Process(target=run_server, args=(runner, )) + proc = multiprocessing.Process(target=run_server, args=(runner,)) proc.start() time.sleep(2) def stop_server(): proc.terminate() - print('Stopping server') + print("Stopping server") request.addfinalizer(stop_server) return proc -@pytest.fixture(scope='class') +@pytest.fixture(scope="class") def mock_server(request): - print('Mock starting server!!!!!!!!!') + print("Mock starting server!!!!!!!!!") return create_mock_server(request) diff --git a/tests/constants.py b/tests/constants.py index c4168be..4a5e157 100644 --- a/tests/constants.py +++ b/tests/constants.py @@ -1,28 +1,29 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
class TestTasks: - streaming_none_text_to_text = 'streaming_none_text_to_text' - streaming_none_text_to_binary = 'streaming_none_text_to_binary' - streaming_none_binary_to_text = 'streaming_none_binary_to_text' - streaming_none_binary_to_binary = 'streaming_none_binary_to_binary' - streaming_in_text_to_text = 'streaming_in_text_to_text' - streaming_in_text_to_binary = 'streaming_in_text_to_binary' - streaming_in_binary_to_text = 'streaming_in_binary_to_text' - streaming_in_binary_to_binary = 'streaming_in_binary_to_binary' - streaming_out_text_to_text = 'streaming_out_text_to_text' - streaming_out_text_to_binary = 'streaming_out_text_to_binary' - streaming_out_binary_to_text = 'streaming_out_binary_to_text' - streaming_out_binary_to_binary = 'streaming_out_binary_to_binary' - streaming_duplex_text_to_text = 'streaming_duplex_text_to_text' - streaming_duplex_text_to_binary = 'streaming_duplex_text_to_binary' - streaming_duplex_binary_to_text = 'streaming_duplex_binary_to_text' - streaming_duplex_binary_to_binary = 'streaming_duplex_binary_to_binary' + streaming_none_text_to_text = "streaming_none_text_to_text" + streaming_none_text_to_binary = "streaming_none_text_to_binary" + streaming_none_binary_to_text = "streaming_none_binary_to_text" + streaming_none_binary_to_binary = "streaming_none_binary_to_binary" + streaming_in_text_to_text = "streaming_in_text_to_text" + streaming_in_text_to_binary = "streaming_in_text_to_binary" + streaming_in_binary_to_text = "streaming_in_binary_to_text" + streaming_in_binary_to_binary = "streaming_in_binary_to_binary" + streaming_out_text_to_text = "streaming_out_text_to_text" + streaming_out_text_to_binary = "streaming_out_text_to_binary" + streaming_out_binary_to_text = "streaming_out_binary_to_text" + streaming_out_binary_to_binary = "streaming_out_binary_to_binary" + streaming_duplex_text_to_text = "streaming_duplex_text_to_text" + streaming_duplex_text_to_binary = "streaming_duplex_text_to_binary" + 
streaming_duplex_binary_to_text = "streaming_duplex_binary_to_text" + streaming_duplex_binary_to_binary = "streaming_duplex_binary_to_binary" -TEST_JOB_ID = '123456' -TEST_ENABLE_DATA_INSPECTION_REQUEST_ID = '11111111' -TEST_DISABLE_DATA_INSPECTION_REQUEST_ID = '22222222' -TEST_TEXT_EMBEDDING_INPUT_WITH_STR = '1' -TEST_TEXT_EMBEDDING_INPUT_WITH_LIST_OF_STR = '2' -TEST_TEXT_EMBEDDING_INPUT_WITH_OPENED_FILE = '3' +TEST_JOB_ID = "123456" +TEST_ENABLE_DATA_INSPECTION_REQUEST_ID = "11111111" +TEST_DISABLE_DATA_INSPECTION_REQUEST_ID = "22222222" +TEST_TEXT_EMBEDDING_INPUT_WITH_STR = "1" +TEST_TEXT_EMBEDDING_INPUT_WITH_LIST_OF_STR = "2" +TEST_TEXT_EMBEDDING_INPUT_WITH_OPENED_FILE = "3" diff --git a/tests/handle_deployment_request.py b/tests/handle_deployment_request.py index e8f1c29..deb5561 100644 --- a/tests/handle_deployment_request.py +++ b/tests/handle_deployment_request.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. import json @@ -10,84 +11,96 @@ async def create_deployment_handler(request: aiohttp.request): body = await request.json() - assert body['model_name'] == 'gpt' - assert body['suffix'] == '1' - assert body['capacity'] == 2 + assert body["model_name"] == "gpt" + assert body["suffix"] == "1" + assert body["capacity"] == 2 # check body info. 
response = { - 'code': '200', - 'output': { - 'deployed_model': 'deploy123456', - 'status': 'PENDING', - 'model_name': 'qwen-turbo-ft-202307121513-5dde' - } + "code": "200", + "output": { + "deployed_model": "deploy123456", + "status": "PENDING", + "model_name": "qwen-turbo-ft-202307121513-5dde", + }, } - return web.json_response(text=json.dumps(response), - content_type='application/json') + return web.json_response( + text=json.dumps(response), + content_type="application/json", + ) async def list_deployment_handler(request: aiohttp.request): response = { - 'status_code': 200, - 'request_id': 'af80b388-b891-43fb-9721-ce5c23d1cafb', - 'code': None, - 'message': '', - 'output': { - 'deployments': [{ - 'deployed_model': 'chatm6-v1-ft-202305230928-b76b', - 'status': 'PENDING', - 'model_name': 'chatm6-v1-ft-202305230928-b76b' - }] + "status_code": 200, + "request_id": "af80b388-b891-43fb-9721-ce5c23d1cafb", + "code": None, + "message": "", + "output": { + "deployments": [ + { + "deployed_model": "chatm6-v1-ft-202305230928-b76b", + "status": "PENDING", + "model_name": "chatm6-v1-ft-202305230928-b76b", + }, + ], }, - 'usage': None + "usage": None, } - return web.json_response(text=json.dumps(response), - content_type='application/json') + return web.json_response( + text=json.dumps(response), + content_type="application/json", + ) async def get_deployment_handler(request: aiohttp.request): - assert request.match_info['id'] == TEST_JOB_ID + assert request.match_info["id"] == TEST_JOB_ID response = { - 'status_code': 200, - 'request_id': '2785283f-10dc-4e4a-80a0-ef4a5fbc6378', - 'code': None, - 'message': '', - 'output': { - 'deployed_model': TEST_JOB_ID, - 'status': 'PENDING', - 'model_name': 'qwen-turbo-ft-202307121513-5dde', - 'capacity': 2 + "status_code": 200, + "request_id": "2785283f-10dc-4e4a-80a0-ef4a5fbc6378", + "code": None, + "message": "", + "output": { + "deployed_model": TEST_JOB_ID, + "status": "PENDING", + "model_name": "qwen-turbo-ft-202307121513-5dde", 
+ "capacity": 2, }, - 'usage': None + "usage": None, } - return web.json_response(text=json.dumps(response), - content_type='application/json') + return web.json_response( + text=json.dumps(response), + content_type="application/json", + ) async def delete_deployment_handler(request: aiohttp.request): - assert request.match_info['id'] == TEST_JOB_ID + assert request.match_info["id"] == TEST_JOB_ID response = { - 'code': 200, - 'request_id': 'test-1223-43043', - 'output': { - 'deployed_model': 'qwen-turbo-ft-202307121513-5dde', + "code": 200, + "request_id": "test-1223-43043", + "output": { + "deployed_model": "qwen-turbo-ft-202307121513-5dde", }, - 'message': '' + "message": "", } - return web.json_response(text=json.dumps(response), - content_type='application/json') + return web.json_response( + text=json.dumps(response), + content_type="application/json", + ) async def events_deployment_handler(request: aiohttp.request): - deployment_id = request.match_info['id'] + deployment_id = request.match_info["id"] assert deployment_id == TEST_JOB_ID response = { - 'code': 200, - 'message': '', - 'output': { - 'events': 'Deployment starting\n Running' - } + "code": 200, + "message": "", + "output": { + "events": "Deployment starting\n Running", + }, } - return web.json_response(text=json.dumps(response), - content_type='application/json') + return web.json_response( + text=json.dumps(response), + content_type="application/json", + ) diff --git a/tests/handle_fine_tune_request.py b/tests/handle_fine_tune_request.py index 9ae8c56..658b2b6 100644 --- a/tests/handle_fine_tune_request.py +++ b/tests/handle_fine_tune_request.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
import asyncio @@ -12,132 +13,147 @@ async def create_fine_tune_handler(request: aiohttp.request): - assert 'X-Request-Id' in request.headers - request_id = request.headers['X-Request-Id'] + assert "X-Request-Id" in request.headers + request_id = request.headers["X-Request-Id"] body = await request.json() - assert body['model'] == 'gpt' or body['model'] == 'asr' - if request_id == '111111': - assert body['training_file_ids'] == 'training_001' - assert body['validation_file_ids'] == 'validation_001' - elif request_id == 'empty_file_ids': - assert len(body['training_file_ids']) == 0 - assert len(body['validation_file_ids']) == 0 + assert body["model"] == "gpt" or body["model"] == "asr" + if request_id == "111111": + assert body["training_file_ids"] == "training_001" + assert body["validation_file_ids"] == "validation_001" + elif request_id == "empty_file_ids": + assert len(body["training_file_ids"]) == 0 + assert len(body["validation_file_ids"]) == 0 else: - assert len(body['training_file_ids']) == 2 - assert len(body['validation_file_ids']) == 2 - if body['model'] == 'asr': - assert 'phrase_list' in body['hyper_parameters'] + assert len(body["training_file_ids"]) == 2 + assert len(body["validation_file_ids"]) == 2 + if body["model"] == "asr": + assert "phrase_list" in body["hyper_parameters"] else: - assert body['hyper_parameters']['epochs'] == 10 + assert body["hyper_parameters"]["epochs"] == 10 # check body info. 
response = { - 'code': '200', - 'output': { - 'job_id': TEST_JOB_ID, - 'status': 'creating', - 'finetuned_output': TEST_JOB_ID - } + "code": "200", + "output": { + "job_id": TEST_JOB_ID, + "status": "creating", + "finetuned_output": TEST_JOB_ID, + }, } - return web.json_response(text=json.dumps(response), - content_type='application/json') + return web.json_response( + text=json.dumps(response), + content_type="application/json", + ) async def list_fine_tune_handler(request: aiohttp.request): response = { - 'output': { - 'jobs': [{ - 'job_id': 'xxxxx', - 'status': 'ready', - 'output_model': 'fine-tuned-xxxxxx', - 'model': '13B', - 'training_file_ids': ['file-xxxxxx', 'file-xxxxxx'], - 'validation_file_ids': ['file-xxxxxx', 'file-xxxxxx'], - 'hyper_parameters': { - 'max_epochs': '3' + "output": { + "jobs": [ + { + "job_id": "xxxxx", + "status": "ready", + "output_model": "fine-tuned-xxxxxx", + "model": "13B", + "training_file_ids": ["file-xxxxxx", "file-xxxxxx"], + "validation_file_ids": ["file-xxxxxx", "file-xxxxxx"], + "hyper_parameters": { + "max_epochs": "3", + }, + "message": "Training failed due to xxxxx reason.", + }, + { + "job_id": "xxxxx", + "status": "ready", + "output_model": "fine-tuned-xxxxxx", + "model": "13B", + "training_file_ids": ["file-xxxxxx", "file-xxxxxx"], + "validation_file_ids": ["file-xxxxxx", "file-xxxxxx"], + "hyper_parameters": { + "max_epochs": "3", + }, + "message": "Training failed due to xxxxx reason.", }, - 'message': 'Training failed due to xxxxx reason.' - }, { - 'job_id': 'xxxxx', - 'status': 'ready', - 'output_model': 'fine-tuned-xxxxxx', - 'model': '13B', - 'training_file_ids': ['file-xxxxxx', 'file-xxxxxx'], - 'validation_file_ids': ['file-xxxxxx', 'file-xxxxxx'], - 'hyper_parameters': { - 'max_epochs': '3' + ], + "finetuned_outputs": [ + { + "finetuned_output": TEST_JOB_ID, + "job_id": "xxxxx", + "model": "asr", }, - 'message': 'Training failed due to xxxxx reason.' 
- }], - 'finetuned_outputs': [{ - 'finetuned_output': TEST_JOB_ID, - 'job_id': 'xxxxx', - 'model': 'asr' - }] - } + ], + }, } - return web.json_response(text=json.dumps(response), - content_type='application/json') + return web.json_response( + text=json.dumps(response), + content_type="application/json", + ) async def get_fine_tune_handler(request: aiohttp.request): - fine_tune_id = request.match_info['id'] + fine_tune_id = request.match_info["id"] assert fine_tune_id == TEST_JOB_ID response = { - 'code': '200', - 'output': { - 'job_id': TEST_JOB_ID, - 'status': 'ready', - 'output_model': 'fine-tuned-xxxxxx', - 'model': '13B', - 'training_file_ids': ['file1', 'file-xxxxxx', 'fiel2'], - 'validation_file_ids': ['file-xxxxxx', 'file-xxxxxx'], - 'hyper_parameters': { - 'max_epochs': 3 + "code": "200", + "output": { + "job_id": TEST_JOB_ID, + "status": "ready", + "output_model": "fine-tuned-xxxxxx", + "model": "13B", + "training_file_ids": ["file1", "file-xxxxxx", "fiel2"], + "validation_file_ids": ["file-xxxxxx", "file-xxxxxx"], + "hyper_parameters": { + "max_epochs": 3, }, - 'finetuned_output': TEST_JOB_ID + "finetuned_output": TEST_JOB_ID, }, - 'message': 'Training failed due to xxxxx reason.' + "message": "Training failed due to xxxxx reason.", } - return web.json_response(text=json.dumps(response), - content_type='application/json') + return web.json_response( + text=json.dumps(response), + content_type="application/json", + ) async def delete_fine_tune_handler(request: aiohttp.request): - fine_tune_id = request.match_info['id'] + fine_tune_id = request.match_info["id"] assert fine_tune_id == TEST_JOB_ID response = { - 'output': { - 'status': 'success', - 'finetuned_output': TEST_JOB_ID + "output": { + "status": "success", + "finetuned_output": TEST_JOB_ID, }, - 'message': 'fine-tune job has been deleted successfully.' 
+ "message": "fine-tune job has been deleted successfully.", } - return web.json_response(text=json.dumps(response), - content_type='application/json') + return web.json_response( + text=json.dumps(response), + content_type="application/json", + ) async def cancel_fine_tune_handler(request: aiohttp.request): - fine_tune_id = request.match_info['id'] + fine_tune_id = request.match_info["id"] assert fine_tune_id == TEST_JOB_ID response = { - 'output': { - 'status': 'success', + "output": { + "status": "success", }, - 'message': 'fine-tune job has been cancel successfully.' + "message": "fine-tune job has been cancel successfully.", } - return web.json_response(text=json.dumps(response), - content_type='application/json') + return web.json_response( + text=json.dumps(response), + content_type="application/json", + ) async def events_fine_tune_handler(request: aiohttp.request): - fine_tune_id = request.match_info['id'] + fine_tune_id = request.match_info["id"] assert fine_tune_id == TEST_JOB_ID async with sse_response(request) as resp: for idx in range(10): - log = 'fine-tune logging %s' % idx - print('Sending sse data: %s' % log) + log = "fine-tune logging %s" % idx + print("Sending sse data: %s" % log) await resp.send(log, id=idx) await asyncio.sleep(1) - print('logging send completed') + print("logging send completed") return await sse_response(request) diff --git a/tests/helper.py b/tests/helper.py index 78394d9..eb95638 100644 --- a/tests/helper.py +++ b/tests/helper.py @@ -1,11 +1,11 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
from collections.abc import Coroutine class _ContextManager(Coroutine): - - __slots__ = ('_coro', '_obj') + __slots__ = ("_coro", "_obj") def __init__(self, coro): self._coro = coro diff --git a/tests/http_task_request.py b/tests/http_task_request.py index d573d31..4c38713 100644 --- a/tests/http_task_request.py +++ b/tests/http_task_request.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. from dashscope.api_entities.dashscope_response import DashScopeAPIResponse @@ -6,46 +7,53 @@ class HttpRequest(BaseApi, BaseAioApi): - """API for AI-Generated Content(AIGC) models. + """API for AI-Generated Content(AIGC) models.""" - """ @classmethod - async def async_call(cls, - model: str, - prompt: str, - task: str, - task_group: str = 'aigc', - api_key: str = None, - api_protocol=ApiProtocol.HTTP, - http_method=HTTPMethod.POST, - is_binary_input=False, - **kwargs) -> DashScopeAPIResponse: - return await super().async_call(model=model, - task_group=task_group, - task=task, - api_key=api_key, - input={'prompt': prompt}, - api_protocol=api_protocol, - is_binary_input=is_binary_input, - **kwargs) + async def async_call( + cls, + model: str, + prompt: str, + task: str, + task_group: str = "aigc", + api_key: str = None, + api_protocol=ApiProtocol.HTTP, + http_method=HTTPMethod.POST, + is_binary_input=False, + **kwargs, + ) -> DashScopeAPIResponse: + return await super().async_call( + model=model, + task_group=task_group, + task=task, + api_key=api_key, + input={"prompt": prompt}, + api_protocol=api_protocol, + is_binary_input=is_binary_input, + **kwargs, + ) @classmethod - def call(cls, - model: str, - prompt: str, - task: str, - function: str, - task_group: str = 'aigc', - api_key: str = None, - api_protocol=ApiProtocol.HTTP, - http_method=HTTPMethod.POST, - is_binary_input=False, - **kwargs) -> DashScopeAPIResponse: - return super().call(model=model, - task_group=task_group, - task=task, - function=function, - api_key=api_key, - 
input={'prompt': prompt}, - api_protocol=api_protocol, - **kwargs) + def call( + cls, + model: str, + prompt: str, + task: str, + function: str, + task_group: str = "aigc", + api_key: str = None, + api_protocol=ApiProtocol.HTTP, + http_method=HTTPMethod.POST, + is_binary_input=False, + **kwargs, + ) -> DashScopeAPIResponse: + return super().call( + model=model, + task_group=task_group, + task=task, + function=function, + api_key=api_key, + input={"prompt": prompt}, + api_protocol=api_protocol, + **kwargs, + ) diff --git a/tests/mock_request_base.py b/tests/mock_request_base.py index b003f51..8dc6817 100644 --- a/tests/mock_request_base.py +++ b/tests/mock_request_base.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. import dashscope @@ -8,16 +9,18 @@ class MockRequestBase(BaseTestEnvironment): @classmethod def setup_class(cls): super().setup_class() - dashscope.base_http_api_url = 'http://localhost:8080/api/v1/' - dashscope.api_key = 'default' - dashscope.api_protocol = 'http' + dashscope.base_http_api_url = "http://localhost:8080/api/v1/" + dashscope.api_key = "default" + dashscope.api_protocol = "http" class MockServerBase(BaseTestEnvironment): @classmethod def setup_class(cls): super().setup_class() - dashscope.base_http_api_url = 'http://localhost:8089/api/v1/' - dashscope.base_websocket_api_url = 'http://localhost:8089/api-ws/v1/inference' - dashscope.api_key = 'default' - dashscope.api_protocol = 'http' + dashscope.base_http_api_url = "http://localhost:8089/api/v1/" + dashscope.base_websocket_api_url = ( + "http://localhost:8089/api-ws/v1/inference" + ) + dashscope.api_key = "default" + dashscope.api_protocol = "http" diff --git a/tests/mock_server.py b/tests/mock_server.py index 830bd30..d9d5920 100644 --- a/tests/mock_server.py +++ b/tests/mock_server.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
import asyncio @@ -11,77 +12,95 @@ from aiohttp import web from dashscope.protocol.websocket import ActionType -from tests.constants import (TEST_DISABLE_DATA_INSPECTION_REQUEST_ID, - TEST_ENABLE_DATA_INSPECTION_REQUEST_ID, - TEST_JOB_ID) -from tests.handle_deployment_request import (create_deployment_handler, - delete_deployment_handler, - events_deployment_handler, - get_deployment_handler, - list_deployment_handler) -from tests.handle_fine_tune_request import (cancel_fine_tune_handler, - create_fine_tune_handler, - delete_fine_tune_handler, - events_fine_tune_handler, - get_fine_tune_handler, - list_fine_tune_handler) +from tests.constants import ( + TEST_DISABLE_DATA_INSPECTION_REQUEST_ID, + TEST_ENABLE_DATA_INSPECTION_REQUEST_ID, + TEST_JOB_ID, +) +from tests.handle_deployment_request import ( + create_deployment_handler, + delete_deployment_handler, + events_deployment_handler, + get_deployment_handler, + list_deployment_handler, +) +from tests.handle_fine_tune_request import ( + cancel_fine_tune_handler, + create_fine_tune_handler, + delete_fine_tune_handler, + events_fine_tune_handler, + get_fine_tune_handler, + list_fine_tune_handler, +) from tests.mock_sse import sse_response from tests.websocket_mock_server_task_handler import WebSocketTaskProcessor def validate_data_inspection_parameter(request: aiohttp.request): - if ('request_id' in request.headers and request.headers['request_id'] - == TEST_ENABLE_DATA_INSPECTION_REQUEST_ID): - assert request.headers['X-DashScope-DataInspection'] == 'enable' + if ( + "request_id" in request.headers + and request.headers["request_id"] + == TEST_ENABLE_DATA_INSPECTION_REQUEST_ID + ): + assert request.headers["X-DashScope-DataInspection"] == "enable" - if ('request_id' in request.headers and request.headers['request_id'] - == TEST_DISABLE_DATA_INSPECTION_REQUEST_ID): - assert 'X-DashScope-DataInspection' not in request.headers + if ( + "request_id" in request.headers + and request.headers["request_id"] + == 
TEST_DISABLE_DATA_INSPECTION_REQUEST_ID + ): + assert "X-DashScope-DataInspection" not in request.headers async def post_echo(request: aiohttp.request): validate_data_inspection_parameter(request) - if 'X-DashScope-SSE' in request.headers: + if "X-DashScope-SSE" in request.headers: await mock_sse(request) return body = await request.json() - print('receive request json:\n %s' % body) - if 'messages' in body['input']: - input_text = body['input']['messages'][0]['content'] + print("receive request json:\n %s" % body) + if "messages" in body["input"]: + input_text = body["input"]["messages"][0]["content"] else: - input_text = body['input']['prompt'] - if 'result_format' in body['parameters'] and body['parameters'][ - 'request_format'] == 'message': + input_text = body["input"]["prompt"] + if ( + "result_format" in body["parameters"] + and body["parameters"]["request_format"] == "message" + ): response = { - 'output': { - 'choices': [{ - 'finish_reasion': 'stop', - 'message': { - 'role': 'assistant', - 'content': input_text - } - }] + "output": { + "choices": [ + { + "finish_reasion": "stop", + "message": { + "role": "assistant", + "content": input_text, + }, + }, + ], }, - 'usage': { - 'output_tokens': 17, - 'input_tokens': 2 + "usage": { + "output_tokens": 17, + "input_tokens": 2, }, - 'request_id': 'd167c38b-bd5d-11ed-981e-00163e0d4788' + "request_id": "d167c38b-bd5d-11ed-981e-00163e0d4788", } else: response = { - 'output': { - 'text': input_text + "output": { + "text": input_text, }, - 'usage': { - 'output_tokens': 17, - 'input_tokens': 2 + "usage": { + "output_tokens": 17, + "input_tokens": 2, }, - 'request_id': 'd167c38b-bd5d-11ed-981e-00163e0d4788' + "request_id": "d167c38b-bd5d-11ed-981e-00163e0d4788", } - return web.json_response(text=json.dumps(response), - content_type='application/json') + return web.json_response( + text=json.dumps(response), + content_type="application/json", + ) async def response_403(request: aiohttp.request): @@ -96,14 +115,20 @@ async 
def websocket_handler_stream_none(request): async for msg in ws: if msg.type == aiohttp.WSMsgType.TEXT: req = msg.json() - if req['header']['action'] == ActionType.START: - task_id = req['header']['task_id'] - streaming_mode = req['header']['streaming'] - print('receive first payload: %s' % req['payload']) - wsc = WebSocketTaskProcessor(ws, task_id, streaming_mode, - req['payload']['model'], - req['payload']['task'], False, - False, req) + if req["header"]["action"] == ActionType.START: + task_id = req["header"]["task_id"] + streaming_mode = req["header"]["streaming"] + print("receive first payload: %s" % req["payload"]) + wsc = WebSocketTaskProcessor( + ws, + task_id, + streaming_mode, + req["payload"]["model"], + req["payload"]["task"], + False, + False, + req, + ) await wsc.aio_call() await ws.close() return ws @@ -116,14 +141,20 @@ async def websocket_handler_stream_in(request): async for msg in ws: if msg.type == aiohttp.WSMsgType.TEXT: req = msg.json() - if req['header']['action'] == ActionType.START: - task_id = req['header']['task_id'] - streaming_mode = req['header']['streaming'] - print('receive first payload: %s' % req['payload']) - wsc = WebSocketTaskProcessor(ws, task_id, streaming_mode, - req['payload']['model'], - req['payload']['task'], True, - False, req) + if req["header"]["action"] == ActionType.START: + task_id = req["header"]["task_id"] + streaming_mode = req["header"]["streaming"] + print("receive first payload: %s" % req["payload"]) + wsc = WebSocketTaskProcessor( + ws, + task_id, + streaming_mode, + req["payload"]["model"], + req["payload"]["task"], + True, + False, + req, + ) await wsc.aio_call() await ws.close() return ws @@ -136,14 +167,20 @@ async def websocket_handler_stream_out(request): async for msg in ws: if msg.type == aiohttp.WSMsgType.TEXT: req = msg.json() - if req['header']['action'] == ActionType.START: - task_id = req['header']['task_id'] - streaming_mode = req['header']['streaming'] - print('receive first payload: %s' % 
req['payload']) - wsc = WebSocketTaskProcessor(ws, task_id, streaming_mode, - req['payload']['model'], - req['payload']['task'], False, - True, req) + if req["header"]["action"] == ActionType.START: + task_id = req["header"]["task_id"] + streaming_mode = req["header"]["streaming"] + print("receive first payload: %s" % req["payload"]) + wsc = WebSocketTaskProcessor( + ws, + task_id, + streaming_mode, + req["payload"]["model"], + req["payload"]["task"], + False, + True, + req, + ) await wsc.aio_call() await ws.close() return ws @@ -156,14 +193,20 @@ async def websocket_handler_stream_in_out(request): async for msg in ws: if msg.type == aiohttp.WSMsgType.TEXT: req = msg.json() - if req['header']['action'] == ActionType.START: - task_id = req['header']['task_id'] - streaming_mode = req['header']['streaming'] - print('receive first payload: %s' % req['payload']) - wsc = WebSocketTaskProcessor(ws, task_id, streaming_mode, - req['payload']['model'], - req['payload']['task'], True, - True, req) + if req["header"]["action"] == ActionType.START: + task_id = req["header"]["task_id"] + streaming_mode = req["header"]["streaming"] + print("receive first payload: %s" % req["payload"]) + wsc = WebSocketTaskProcessor( + ws, + task_id, + streaming_mode, + req["payload"]["model"], + req["payload"]["task"], + True, + True, + req, + ) await wsc.aio_call() await ws.close() return ws @@ -172,54 +215,56 @@ async def websocket_handler_stream_in_out(request): async def mock_sse(request): async with sse_response(request) as resp: for idx in range(10): - data = '{}'.format(idx) + data = "{}".format(idx) response = { - 'output': { - 'text': data + "output": { + "text": data, }, - 'usage': { - 'output_tokens': 17, - 'input_tokens': 2 + "usage": { + "output_tokens": 17, + "input_tokens": 2, }, - 'request_id': 'd167c38b-bd5d-11ed-981e-00163e0d4788' + "request_id": "d167c38b-bd5d-11ed-981e-00163e0d4788", } response_str = json.dumps(response) - print('Sending sse data: %s' % response_str) + 
print("Sending sse data: %s" % response_str) await resp.send(response_str, id=idx) await asyncio.sleep(1) - print('data send completed') + print("data send completed") async def handle_send_receive_form_data(request: aiohttp.request): data_hash = { - 'dog': '1d5ee55c2453009b14db98e74c453abb', - 'bird': '24c2f9abdb8809982d5bd2e10f1f98d7' + "dog": "1d5ee55c2453009b14db98e74c453abb", + "bird": "24c2f9abdb8809982d5bd2e10f1f98d7", } reader = await request.multipart() # dog, and bird, async for field in reader: - print('multipart field: %s' % field.name) + print("multipart field: %s" % field.name) content = await field.read() real_md5 = md5(content).hexdigest() - if field.name == 'files': + if field.name == "files": assert real_md5 in data_hash.values() # response file to client response = { - 'output': { - 'text': 'return a text' + "output": { + "text": "return a text", }, - 'usage': { - 'output_tokens': 17, - 'input_tokens': 2 + "usage": { + "output_tokens": 17, + "input_tokens": 2, }, - 'request_id': 'd167c38b-bd5d-11ed-981e-00163e0d4788' + "request_id": "d167c38b-bd5d-11ed-981e-00163e0d4788", } - return web.json_response(text=json.dumps(response), - content_type='application/json') + return web.json_response( + text=json.dumps(response), + content_type="application/json", + ) async def handle_upload_file(request: aiohttp.request): - dog_file_md5 = '1d5ee55c2453009b14db98e74c453abb' + dog_file_md5 = "1d5ee55c2453009b14db98e74c453abb" reader = await request.multipart() # dog, and bird, async for field in reader: @@ -227,221 +272,261 @@ async def handle_upload_file(request: aiohttp.request): real_md5 = md5(content).hexdigest() assert real_md5 == dog_file_md5 response = { - 'data': { - 'uploaded_files': [{ - 'file_id': 'xxxx', - 'name': 'test.txt' - }], + "data": { + "uploaded_files": [ + { + "file_id": "xxxx", + "name": "test.txt", + }, + ], }, - 'request_id': 'd167c38b-bd5d-11ed-981e-00163e0d4788' + "request_id": "d167c38b-bd5d-11ed-981e-00163e0d4788", } - return 
web.json_response(text=json.dumps(response), - content_type='application/json') + return web.json_response( + text=json.dumps(response), + content_type="application/json", + ) async def handle_list_file(request: aiohttp.request): response = { - 'request_id': 'd7bd0668-8bd7-486c-8383-d1582c4b44f0', - 'code': 0, - 'msg': '操作成功', - 'data': { - 'total': - 3, - 'page_size': - 20, - 'page_no': - 1, - 'files': [{ - 'id': - 11, - 'file_id': - 'da55d958-fbb2-4ed9-b979-f29af139d6f3', - 'name': - 'fine_tune_example.jsonl', - 'description': - 'testfilesfasfdsf', - 'url': - 'http://dashscope.oss-cn-beijing.aliyuncs.com/api-fs/1' - }, { - 'id': - 10, - 'file_id': - 'fedffd0c-c247-4442-ae93-cf8525786e6c', - 'name': - 'fine_tune_example.jsonl', - 'description': - 'testfilesfasfdsf', - 'url': - 'http://dashscope.oss-cn-beijing.aliyuncs.com/api-fs/2' - }, { - 'id': - 9, - 'file_id': - '13ee1928-3ce4-494c-96a8-27219aec298e', - 'name': - 'fine_tune_example.jsonl', - 'description': - 'testfilesfasfdsf', - 'url': - 'http://dashscope.oss-cn-beijing.aliyuncs.com/api-fs/3' - }] + "request_id": "d7bd0668-8bd7-486c-8383-d1582c4b44f0", + "code": 0, + "msg": "操作成功", + "data": { + "total": 3, + "page_size": 20, + "page_no": 1, + "files": [ + { + "id": 11, + "file_id": "da55d958-fbb2-4ed9-b979-f29af139d6f3", + "name": "fine_tune_example.jsonl", + "description": "testfilesfasfdsf", + "url": "http://dashscope.oss-cn-beijing.aliyuncs.com/api-fs/1", + }, + { + "id": 10, + "file_id": "fedffd0c-c247-4442-ae93-cf8525786e6c", + "name": "fine_tune_example.jsonl", + "description": "testfilesfasfdsf", + "url": "http://dashscope.oss-cn-beijing.aliyuncs.com/api-fs/2", + }, + { + "id": 9, + "file_id": "13ee1928-3ce4-494c-96a8-27219aec298e", + "name": "fine_tune_example.jsonl", + "description": "testfilesfasfdsf", + "url": "http://dashscope.oss-cn-beijing.aliyuncs.com/api-fs/3", + }, + ], }, - 'success': True + "success": True, } - return web.json_response(text=json.dumps(response, ensure_ascii=True), - 
content_type='application/json') + return web.json_response( + text=json.dumps(response, ensure_ascii=True), + content_type="application/json", + ) async def handle_get_file(request: aiohttp.request): - id = request.match_info['id'] + id = request.match_info["id"] response = { - 'request_id': 'e2faec4a-1183-47e5-9279-222e0a762c61', - 'code': 0, - 'msg': '操作成功', - 'data': { - 'id': 11, - 'file_id': id, - 'name': 'fine_tune_example.jsonl', - 'description': 'testfilesfasfdsf', - 'url': 'http://dashscope.oss-cn-beijing.aliyuncs.com/api-fs/1' + "request_id": "e2faec4a-1183-47e5-9279-222e0a762c61", + "code": 0, + "msg": "操作成功", + "data": { + "id": 11, + "file_id": id, + "name": "fine_tune_example.jsonl", + "description": "testfilesfasfdsf", + "url": "http://dashscope.oss-cn-beijing.aliyuncs.com/api-fs/1", }, - 'success': True + "success": True, } - return web.json_response(text=json.dumps(response), - content_type='application/json') + return web.json_response( + text=json.dumps(response), + content_type="application/json", + ) async def handle_delete_file(request: aiohttp.request): - id = request.match_info['id'] - if id == '111111': - response = {'code': '200', 'success': True} - return web.json_response(text=json.dumps(response), - content_type='application/json') - elif id == '222222': - rsp = {'code': '404', 'success': False} - return web.json_response(status=HTTPStatus.NOT_FOUND, - text=json.dumps(rsp), - content_type='application/json') - elif id == '333333': - rsp = {'code': '403', 'success': False} - return web.json_response(status=HTTPStatus.FORBIDDEN, - text=json.dumps(rsp), - content_type='application/json') - elif id == '333333': - rsp = {'code': '403', 'success': False} - return web.json_response(status=HTTPStatus.FORBIDDEN, - text=json.dumps(rsp), - content_type='application/json') - elif id == '444444': - assert request.headers['Authorization'] == 'Bearer api-key' - rsp = {'code': '401', 'success': False} - return 
web.json_response(status=HTTPStatus.UNAUTHORIZED, - text=json.dumps(rsp), - content_type='application/json') + id = request.match_info["id"] + if id == "111111": + response = {"code": "200", "success": True} + return web.json_response( + text=json.dumps(response), + content_type="application/json", + ) + elif id == "222222": + rsp = {"code": "404", "success": False} + return web.json_response( + status=HTTPStatus.NOT_FOUND, + text=json.dumps(rsp), + content_type="application/json", + ) + elif id == "333333": + rsp = {"code": "403", "success": False} + return web.json_response( + status=HTTPStatus.FORBIDDEN, + text=json.dumps(rsp), + content_type="application/json", + ) + elif id == "333333": + rsp = {"code": "403", "success": False} + return web.json_response( + status=HTTPStatus.FORBIDDEN, + text=json.dumps(rsp), + content_type="application/json", + ) + elif id == "444444": + assert request.headers["Authorization"] == "Bearer api-key" + rsp = {"code": "401", "success": False} + return web.json_response( + status=HTTPStatus.UNAUTHORIZED, + text=json.dumps(rsp), + content_type="application/json", + ) async def list_models_handler(request: aiohttp.request): response = { - 'code': '200', - 'data': { - 'models': [ + "code": "200", + "data": { + "models": [ { - 'model_id': '1111', - 'gmt_create': '2023-03-15 14:25:50', + "model_id": "1111", + "gmt_create": "2023-03-15 14:25:50", }, { - 'model_id': '2222', - 'gmt_create': '2023-03-15 14:25:50', + "model_id": "2222", + "gmt_create": "2023-03-15 14:25:50", }, ], - } + }, } - return web.json_response(text=json.dumps(response), - content_type='application/json') + return web.json_response( + text=json.dumps(response), + content_type="application/json", + ) async def get_model_handler(request: aiohttp.request): - model_id = request.match_info['id'] + model_id = request.match_info["id"] assert model_id == TEST_JOB_ID response = { - 'code': '200', - 'data': { - 'model_id': TEST_JOB_ID, - 'name': 'gpt3' - } + "code": "200", + 
"data": { + "model_id": TEST_JOB_ID, + "name": "gpt3", + }, } - return web.json_response(text=json.dumps(response), - content_type='application/json') + return web.json_response( + text=json.dumps(response), + content_type="application/json", + ) async def handle_text_embedding(request: aiohttp.request): body = await request.json() - assert len(body['input']['texts']) >= 1 + assert len(body["input"]["texts"]) >= 1 response = { - 'output': { - 'embeddings': [{ - 'text_index': - 0, - 'embedding': [-0.006929283495992422, -0.005336422007530928] - }] + "output": { + "embeddings": [ + { + "text_index": 0, + "embedding": [ + -0.006929283495992422, + -0.005336422007530928, + ], + }, + ], }, - 'usage': { - 'input_tokens': 12 + "usage": { + "input_tokens": 12, }, - 'request_id': 'd89c06fb-46a1-47b6-acb9-bfb17f814969' + "request_id": "d89c06fb-46a1-47b6-acb9-bfb17f814969", } - return web.json_response(text=json.dumps(response), - content_type='application/json') + return web.json_response( + text=json.dumps(response), + content_type="application/json", + ) def create_app(): app = web.Application() - app.router.add_post('/api/v1/services/aigc/generation', post_echo) - app.router.add_post('/api/v1/services/aigc/text-generation/generation', - post_echo) + app.router.add_post("/api/v1/services/aigc/generation", post_echo) app.router.add_post( - '/api/v1/services/embeddings/text-embedding/text-embedding', - handle_text_embedding) - app.router.add_post('/api/v1/services/aigc/forbidden', response_403) - app.router.add_post('/api/v1/services/aigc/image-generation/generation', - handle_send_receive_form_data) - app.router.add_route('GET', '/ws/aigc/v1/echo', - websocket_handler_stream_none) - app.router.add_route('GET', '/ws/aigc/v1/in', websocket_handler_stream_in) - app.router.add_route('GET', '/ws/aigc/v1/out', - websocket_handler_stream_out) - app.router.add_route('GET', '/ws/aigc/v1/inout', - websocket_handler_stream_in_out) + "/api/v1/services/aigc/text-generation/generation", + 
post_echo, + ) + app.router.add_post( + "/api/v1/services/embeddings/text-embedding/text-embedding", + handle_text_embedding, + ) + app.router.add_post("/api/v1/services/aigc/forbidden", response_403) + app.router.add_post( + "/api/v1/services/aigc/image-generation/generation", + handle_send_receive_form_data, + ) + app.router.add_route( + "GET", + "/ws/aigc/v1/echo", + websocket_handler_stream_none, + ) + app.router.add_route("GET", "/ws/aigc/v1/in", websocket_handler_stream_in) + app.router.add_route( + "GET", + "/ws/aigc/v1/out", + websocket_handler_stream_out, + ) + app.router.add_route( + "GET", + "/ws/aigc/v1/inout", + websocket_handler_stream_in_out, + ) # file upload - app.router.add_post('/api/v1/files', handle_upload_file) - app.router.add_get('/api/v1/files', handle_list_file) - app.router.add_get('/api/v1/files/{id}', handle_get_file) - app.router.add_delete('/api/v1/files/{id}', handle_delete_file) + app.router.add_post("/api/v1/files", handle_upload_file) + app.router.add_get("/api/v1/files", handle_list_file) + app.router.add_get("/api/v1/files/{id}", handle_get_file) + app.router.add_delete("/api/v1/files/{id}", handle_delete_file) # fine-tune - app.router.add_post('/api/v1/fine-tunes', create_fine_tune_handler) - app.router.add_get('/api/v1/fine-tunes', list_fine_tune_handler) - app.router.add_get('/api/v1/fine-tunes/outputs', list_fine_tune_handler) - app.router.add_get('/api/v1/fine-tunes/{id}', get_fine_tune_handler) - app.router.add_get('/api/v1/fine-tunes/outputs/{id}', - get_fine_tune_handler) - app.router.add_delete('/api/v1/fine-tunes/{id}', delete_fine_tune_handler) - app.router.add_delete('/api/v1/fine-tunes/outputs/{id}', - delete_fine_tune_handler) - app.router.add_post('/api/v1/fine-tunes/{id}/cancel', - cancel_fine_tune_handler) - app.router.add_get('/api/v1/fine-tunes/{id}/stream', - events_fine_tune_handler) - - app.router.add_post('/api/v1/deployments', create_deployment_handler) - app.router.add_get('/api/v1/deployments', 
list_deployment_handler) - app.router.add_get('/api/v1/deployments/{id}', get_deployment_handler) - app.router.add_delete('/api/v1/deployments/{id}', - delete_deployment_handler) - app.router.add_get('/api/v1/deployments/{id}/events', - events_deployment_handler) - - app.router.add_get('/api/v1/models', list_models_handler) - app.router.add_get('/api/v1/models/{id}', get_model_handler) + app.router.add_post("/api/v1/fine-tunes", create_fine_tune_handler) + app.router.add_get("/api/v1/fine-tunes", list_fine_tune_handler) + app.router.add_get("/api/v1/fine-tunes/outputs", list_fine_tune_handler) + app.router.add_get("/api/v1/fine-tunes/{id}", get_fine_tune_handler) + app.router.add_get( + "/api/v1/fine-tunes/outputs/{id}", + get_fine_tune_handler, + ) + app.router.add_delete("/api/v1/fine-tunes/{id}", delete_fine_tune_handler) + app.router.add_delete( + "/api/v1/fine-tunes/outputs/{id}", + delete_fine_tune_handler, + ) + app.router.add_post( + "/api/v1/fine-tunes/{id}/cancel", + cancel_fine_tune_handler, + ) + app.router.add_get( + "/api/v1/fine-tunes/{id}/stream", + events_fine_tune_handler, + ) + + app.router.add_post("/api/v1/deployments", create_deployment_handler) + app.router.add_get("/api/v1/deployments", list_deployment_handler) + app.router.add_get("/api/v1/deployments/{id}", get_deployment_handler) + app.router.add_delete( + "/api/v1/deployments/{id}", + delete_deployment_handler, + ) + app.router.add_get( + "/api/v1/deployments/{id}/events", + events_deployment_handler, + ) + + app.router.add_get("/api/v1/models", list_models_handler) + app.router.add_get("/api/v1/models/{id}", get_model_handler) runner = web.AppRunner(app) return runner @@ -451,110 +536,159 @@ def run_server(runner): loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) loop.run_until_complete(runner.setup()) - site = web.TCPSite(runner, '0.0.0.0', 8080) + site = web.TCPSite(runner, "0.0.0.0", 8080) loop.run_until_complete(site.start()) loop.run_forever() -class MockServer(): 
+class MockServer: def __init__(self) -> None: self.requests = multiprocessing.Queue() self.responses = multiprocessing.Queue() app = web.Application() - app.router.add_post('/api/v1/services/rerank/text-rerank/text-rerank', - self.handle_post) + app.router.add_post( + "/api/v1/services/rerank/text-rerank/text-rerank", + self.handle_post, + ) # fine-tune - app.router.add_post('/api/v1/fine-tunes', self.handle_post) - app.router.add_get('/api/v1/fine-tunes', self.handle_get) - app.router.add_get('/api/v1/fine-tunes/outputs', self.handle_get) - app.router.add_get('/api/v1/fine-tunes/{id}', self.handle_get) - app.router.add_get('/api/v1/fine-tunes/outputs/{id}', self.handle_get) - app.router.add_delete('/api/v1/fine-tunes/{id}', self.handle_get) - app.router.add_delete('/api/v1/fine-tunes/outputs/{id}', - self.handle_post) - app.router.add_post('/api/v1/fine-tunes/{id}/cancel', - self.handle_get) # no body - app.router.add_get('/api/v1/fine-tunes/{id}/stream', - events_fine_tune_handler) + app.router.add_post("/api/v1/fine-tunes", self.handle_post) + app.router.add_get("/api/v1/fine-tunes", self.handle_get) + app.router.add_get("/api/v1/fine-tunes/outputs", self.handle_get) + app.router.add_get("/api/v1/fine-tunes/{id}", self.handle_get) + app.router.add_get("/api/v1/fine-tunes/outputs/{id}", self.handle_get) + app.router.add_delete("/api/v1/fine-tunes/{id}", self.handle_get) + app.router.add_delete( + "/api/v1/fine-tunes/outputs/{id}", + self.handle_post, + ) + app.router.add_post( + "/api/v1/fine-tunes/{id}/cancel", + self.handle_get, + ) # no body + app.router.add_get( + "/api/v1/fine-tunes/{id}/stream", + events_fine_tune_handler, + ) # end of finetune # create assistant file - app.router.add_post('/api/v1/assistants/{assistant_id}/files', - self.handle_post) + app.router.add_post( + "/api/v1/assistants/{assistant_id}/files", + self.handle_post, + ) # retrieve assistant file - app.router.add_get('/api/v1/assistants/{assistant_id}/files/{file_id}', - 
self.handle_get) + app.router.add_get( + "/api/v1/assistants/{assistant_id}/files/{file_id}", + self.handle_get, + ) # delete assistant file app.router.add_delete( - '/api/v1/assistants/{assistant_id}/files/{file_id}', - self.handle_get) + "/api/v1/assistants/{assistant_id}/files/{file_id}", + self.handle_get, + ) # list assistant file - app.router.add_get('/api/v1/assistants/{assistant_id}/files', - self.handle_get) + app.router.add_get( + "/api/v1/assistants/{assistant_id}/files", + self.handle_get, + ) # create messages - app.router.add_post('/api/v1/threads/{thread_id}/messages', - self.handle_create_object) + app.router.add_post( + "/api/v1/threads/{thread_id}/messages", + self.handle_create_object, + ) # list messages - app.router.add_get('/api/v1/threads/{thread_id}/messages', - self.handle_list_object) + app.router.add_get( + "/api/v1/threads/{thread_id}/messages", + self.handle_list_object, + ) # retrieve message - app.router.add_get('/api/v1/threads/{thread_id}/messages/{message_id}', - self.handle_list_object) + app.router.add_get( + "/api/v1/threads/{thread_id}/messages/{message_id}", + self.handle_list_object, + ) # create run - app.router.add_post('/api/v1/threads/{thread_id}/runs', - self.handle_create_object) + app.router.add_post( + "/api/v1/threads/{thread_id}/runs", + self.handle_create_object, + ) # list runs - app.router.add_get('/api/v1/threads/{thread_id}/runs', - self.handle_list_object) + app.router.add_get( + "/api/v1/threads/{thread_id}/runs", + self.handle_list_object, + ) # retrieve run - app.router.add_get('/api/v1/threads/{thread_id}/runs/{run_id}', - self.handle_list_object) + app.router.add_get( + "/api/v1/threads/{thread_id}/runs/{run_id}", + self.handle_list_object, + ) # cancel run - app.router.add_post('/api/v1/threads/{thread_id}/runs/{run_id}/cancel', - self.handle_cancel_object) + app.router.add_post( + "/api/v1/threads/{thread_id}/runs/{run_id}/cancel", + self.handle_cancel_object, + ) # retrieve run steps 
app.router.add_get( - '/api/v1/threads/{thread_id}/runs/{run_id}/steps/{step_id}', - self.handle_list_object) + "/api/v1/threads/{thread_id}/runs/{run_id}/steps/{step_id}", + self.handle_list_object, + ) # list message files app.router.add_get( - '/api/v1/threads/{thread_id}/messages/{message_id}/files', - self.handle_list_object) + "/api/v1/threads/{thread_id}/messages/{message_id}/files", + self.handle_list_object, + ) # retrieve message file app.router.add_get( - '/api/v1/threads/{thread_id}/messages/{message_id}/files/{file_id}', - self.handle_list_object) + "/api/v1/threads/{thread_id}/messages/{message_id}/files/{file_id}", + self.handle_list_object, + ) # retrieve message file app.router.add_post( - '/api/v1/threads/{thread_id}/messages/{message_id}', - self.handle_post) + "/api/v1/threads/{thread_id}/messages/{message_id}", + self.handle_post, + ) # submit tool result app.router.add_post( - '/api/v1/threads/{thread_id}/runs/{run_id}/submit_tool_outputs', - self.handle_create_object) - app.router.add_get('/api/v1/threads/{thread_id}/runs/{run_id}/steps', - self.handle_list_object) - - app.router.add_post('/api/v1/services/{group}/{task}/{function}', - self.handle_mock_request) - app.router.add_post('/api/v1/{group}/{task}/{function}', - self.handle_mock_request) - app.router.add_route('GET', '/api-ws/v1/inference', - self.websocket_handler) + "/api/v1/threads/{thread_id}/runs/{run_id}/submit_tool_outputs", + self.handle_create_object, + ) + app.router.add_get( + "/api/v1/threads/{thread_id}/runs/{run_id}/steps", + self.handle_list_object, + ) + + app.router.add_post( + "/api/v1/services/{group}/{task}/{function}", + self.handle_mock_request, + ) + app.router.add_post( + "/api/v1/{group}/{task}/{function}", + self.handle_mock_request, + ) + app.router.add_route( + "GET", + "/api-ws/v1/inference", + self.websocket_handler, + ) # create an object - app.router.add_post('/api/v1/{function}', self.handle_create_object) + app.router.add_post("/api/v1/{function}", 
self.handle_create_object) # list objects - app.router.add_get('/api/v1/{function}', self.handle_list_object) + app.router.add_get("/api/v1/{function}", self.handle_list_object) # delete object - app.router.add_delete('/api/v1/{function}/{object_id}', - self.handle_delete_object) + app.router.add_delete( + "/api/v1/{function}/{object_id}", + self.handle_delete_object, + ) # retrieve object - app.router.add_get('/api/v1/{function}/{object_id}', - self.handle_retrieve_object) + app.router.add_get( + "/api/v1/{function}/{object_id}", + self.handle_retrieve_object, + ) # update with post - app.router.add_post('/api/v1/{function}/{object_id}', - self.handle_update_object_with_post) + app.router.add_post( + "/api/v1/{function}/{object_id}", + self.handle_update_object_with_post, + ) self.runner = web.AppRunner(app) @@ -566,8 +700,8 @@ def process_response(self, rsp_str) -> Tuple[int, str]: """ rsp_json = json.loads(rsp_str) status_code = 200 - if 'status_code' in rsp_json: - status_code = rsp_json.pop('status_code') + if "status_code" in rsp_json: + status_code = rsp_json.pop("status_code") rsp_str = json.dumps(rsp_json) return status_code, rsp_str @@ -580,14 +714,16 @@ async def handle_get(self, request: aiohttp.web.BaseRequest): headers = {} # convert raw bytes to str for key, value in request.raw_headers: - headers[key.decode('utf-8')] = value.decode('utf-8') - self.requests.put({'path': request.raw_path, 'headers': headers}) + headers[key.decode("utf-8")] = value.decode("utf-8") + self.requests.put({"path": request.raw_path, "headers": headers}) rsp = self.responses.get(block=True) status_code, rsp = self.process_response(rsp) obj = json.loads(rsp) - return web.json_response(text=json.dumps(obj), - status=status_code, - content_type='application/json') + return web.json_response( + text=json.dumps(obj), + status=status_code, + content_type="application/json", + ) async def handle_post(self, request: aiohttp.web.BaseRequest): """Handle post request, put path, body, 
headers to requests. @@ -600,69 +736,83 @@ async def handle_post(self, request: aiohttp.web.BaseRequest): headers = {} # convert raw bytes to str for key, value in request.raw_headers: - headers[key.decode('utf-8')] = value.decode('utf-8') - self.requests.put({ - 'body': body, - 'path': request.raw_path, - 'headers': headers - }) + headers[key.decode("utf-8")] = value.decode("utf-8") + self.requests.put( + { + "body": body, + "path": request.raw_path, + "headers": headers, + }, + ) rsp = self.responses.get(block=True) status_code, rsp = self.process_response(rsp) obj = json.loads(rsp) - return web.json_response(text=json.dumps(obj), - status=status_code, - content_type='application/json') + return web.json_response( + text=json.dumps(obj), + status=status_code, + content_type="application/json", + ) def handle_cancel_object(self, request: aiohttp.request): self.requests.put(request.raw_path) rsp = self.responses.get(block=True) status_code, rsp = self.process_response(rsp) - return web.json_response(text=rsp, - status=status_code, - content_type='application/json') + return web.json_response( + text=rsp, + status=status_code, + content_type="application/json", + ) def handle_delete_object(self, request: aiohttp.request): - object_id = request.match_info['object_id'] + object_id = request.match_info["object_id"] self.requests.put(object_id) rsp = self.responses.get(block=True) status_code, rsp = self.process_response(rsp) - return web.json_response(text=rsp, - status=status_code, - content_type='application/json') + return web.json_response( + text=rsp, + status=status_code, + content_type="application/json", + ) # response path of request def handle_list_object(self, request: aiohttp.request): self.requests.put(request.raw_path) rsp = self.responses.get(block=True) status_code, rsp = self.process_response(rsp) - return web.json_response(text=rsp, - status=status_code, - content_type='application/json') + return web.json_response( + text=rsp, + status=status_code, + 
content_type="application/json", + ) async def handle_update_object_with_post(self, request: aiohttp.request): - func = request.match_info['function'] - object_id = request.match_info['object_id'] - print('function: %s' % (func)) + func = request.match_info["function"] + object_id = request.match_info["object_id"] + print("function: %s" % (func)) body = await request.json() self.requests.put(body) rsp = self.responses.get(block=True) status_code, rsp = self.process_response(rsp) obj = json.loads(rsp) - obj['id'] = object_id - return web.json_response(text=json.dumps(obj), - status=status_code, - content_type='application/json') + obj["id"] = object_id + return web.json_response( + text=json.dumps(obj), + status=status_code, + content_type="application/json", + ) def handle_retrieve_object(self, request: aiohttp.request): - func = request.match_info['function'] - object_id = request.match_info['object_id'] - print('Retrieve %s, object_id: %s' % (func, object_id)) + func = request.match_info["function"] + object_id = request.match_info["object_id"] + print("Retrieve %s, object_id: %s" % (func, object_id)) self.requests.put(object_id) rsp = self.responses.get(block=True) status_code, rsp = self.process_response(rsp) - return web.json_response(text=rsp, - status=status_code, - content_type='application/json') + return web.json_response( + text=rsp, + status=status_code, + content_type="application/json", + ) def add_response(self, response: web.Response): self.responses.append(response) @@ -681,17 +831,19 @@ async def handle_create_object(self, request: aiohttp.request): self.requests.put(body) rsp = self.responses.get(block=True) status_code, rsp = self.process_response(rsp) - return web.json_response(text=rsp, - status=status_code, - content_type='application/json') + return web.json_response( + text=rsp, + status=status_code, + content_type="application/json", + ) async def handle_mock_request(self, request: aiohttp.request): - group = request.match_info['group'] - 
task = request.match_info['task'] - func = request.match_info['function'] - print('group: %s, task: %s, function: %s' % (group, task, func)) + group = request.match_info["group"] + task = request.match_info["task"] + func = request.match_info["function"] + print("group: %s, task: %s, function: %s" % (group, task, func)) body = await request.json() - print('handle_mock_request body', str(body)) + print("handle_mock_request body", str(body)) self.requests.put(body) rsp = self.responses.get(block=True) @@ -699,14 +851,17 @@ async def handle_mock_request(self, request: aiohttp.request): status = 200 try: json_resp = json.loads(rsp) - status = json_resp.get('status_code', 200) + status = json_resp.get("status_code", 200) except Exception: print( - 'can not find status code from response, will use default 200') + "can not find status code from response, will use default 200", + ) - return web.json_response(status=status, - text=rsp, - content_type='application/json') + return web.json_response( + status=status, + text=rsp, + content_type="application/json", + ) async def websocket_handler(self, request): ws = aiohttp.web.WebSocketResponse(heartbeat=100) @@ -716,14 +871,20 @@ async def websocket_handler(self, request): if msg.type == aiohttp.WSMsgType.TEXT: req = msg.json() self.requests.put(req) - if req['header']['action'] == ActionType.START: - task_id = req['header']['task_id'] - streaming_mode = req['header']['streaming'] - print('receive first payload: %s' % req['payload']) - wsc = WebSocketTaskProcessor(ws, task_id, streaming_mode, - req['payload']['model'], - req['payload']['task'], False, - False, req) + if req["header"]["action"] == ActionType.START: + task_id = req["header"]["task_id"] + streaming_mode = req["header"]["streaming"] + print("receive first payload: %s" % req["payload"]) + wsc = WebSocketTaskProcessor( + ws, + task_id, + streaming_mode, + req["payload"]["model"], + req["payload"]["task"], + False, + False, + req, + ) await wsc.aio_call() await 
ws.close() return ws @@ -731,18 +892,19 @@ async def websocket_handler(self, request): def http_server(): runner = create_app() - proc = multiprocessing.Process(target=run_server, args=(runner, )) + proc = multiprocessing.Process(target=run_server, args=(runner,)) proc.start() def stop_server(): proc.terminate() - print('Server stopped') + print("Server stopped") return proc def run_mock_server(requests, responses): from signal import signal, SIGPIPE, SIG_DFL + signal(SIGPIPE, SIG_DFL) mock_web_server = MockServer() mock_web_server.requests = requests @@ -751,33 +913,36 @@ def run_mock_server(requests, responses): loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) loop.run_until_complete(runner.setup()) - site = web.TCPSite(runner, '0.0.0.0', 8089) + site = web.TCPSite(runner, "0.0.0.0", 8089) loop.run_until_complete(site.start()) loop.run_forever() - print('Server started!!!!!!!!!!!') + print("Server started!!!!!!!!!!!") def create_mock_server(request): mock_web_server = MockServer() - proc = multiprocessing.Process(target=run_mock_server, - args=( - mock_web_server.requests, - mock_web_server.responses, - )) + proc = multiprocessing.Process( + target=run_mock_server, + args=( + mock_web_server.requests, + mock_web_server.responses, + ), + ) proc.start() import time + time.sleep(2) def stop_server(): proc.terminate() - print('Mock server stopped') + print("Mock server stopped") request.addfinalizer(stop_server) return mock_web_server -if __name__ == '__main__': +if __name__ == "__main__": proc = http_server() proc.join() diff --git a/tests/mock_sse.py b/tests/mock_sse.py index 1eeccc5..ebeab71 100644 --- a/tests/mock_sse.py +++ b/tests/mock_sse.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
# from: https://github.com/aio-libs/aiohttp-sse/blob/master/aiohttp_sse/__init__.py # noqa E501 @@ -11,7 +12,7 @@ from .helper import _ContextManager -__all__ = ['EventSourceResponse', 'sse_response'] +__all__ = ["EventSourceResponse", "sse_response"] class EventSourceResponse(StreamResponse): @@ -28,8 +29,8 @@ async def hello(request): """ DEFAULT_PING_INTERVAL = 60 - DEFAULT_SEPARATOR = '\r\n' - LINE_SEP_EXPR = re.compile(r'\r\n|\r|\n') + DEFAULT_SEPARATOR = "\r\n" + LINE_SEP_EXPR = re.compile(r"\r\n|\r|\n") def __init__(self, status=200, reason=None, headers=None, sep=None): super().__init__(status=status, reason=reason) @@ -37,10 +38,10 @@ def __init__(self, status=200, reason=None, headers=None, sep=None): self.headers.extend(headers) # mandatory for servers-sent events headers - self.headers['Content-Type'] = 'text/event-stream' - self.headers['Cache-Control'] = 'no-cache' - self.headers['Connection'] = 'keep-alive' - self.headers['X-Accel-Buffering'] = 'no' + self.headers["Content-Type"] = "text/event-stream" + self.headers["Cache-Control"] = "no-cache" + self.headers["Connection"] = "keep-alive" + self.headers["X-Accel-Buffering"] = "no" self._ping_interval = self.DEFAULT_PING_INTERVAL self._ping_task = None @@ -55,15 +56,15 @@ async def prepare(self, request): :param request: regular aiohttp.web.Request. """ - if request.method not in ['GET', 'POST']: - raise HTTPMethodNotAllowed(request.method, ['GET', 'POST']) + if request.method not in ["GET", "POST"]: + raise HTTPMethodNotAllowed(request.method, ["GET", "POST"]) if not self.prepared: writer = await super().prepare(request) self._ping_task = asyncio.create_task(self._ping()) - if request.method == 'POST': + if request.method == "POST": request_str = await request.json() - print('Request content: %s' % request_str) + print("Request content: %s" % request_str) # explicitly enabling chunked encoding, since content length # usually not known beforehand. 
self.enable_chunked_encoding() @@ -93,25 +94,25 @@ async def send(self, data, id=None, event=None, retry=None): """ buffer = io.StringIO() if id is not None: - buffer.write(self.LINE_SEP_EXPR.sub('', f'id: {id}')) + buffer.write(self.LINE_SEP_EXPR.sub("", f"id: {id}")) buffer.write(self._sep) if event is not None: - buffer.write(self.LINE_SEP_EXPR.sub('', f'event: {event}')) + buffer.write(self.LINE_SEP_EXPR.sub("", f"event: {event}")) buffer.write(self._sep) for chunk in self.LINE_SEP_EXPR.split(data): - buffer.write(f'data: {chunk}') + buffer.write(f"data: {chunk}") buffer.write(self._sep) if retry is not None: if not isinstance(retry, int): - raise TypeError('retry argument must be int') - buffer.write(f'retry: {retry}') + raise TypeError("retry argument must be int") + buffer.write(f"retry: {retry}") buffer.write(self._sep) buffer.write(self._sep) - await self.write(buffer.getvalue().encode('utf-8')) + await self.write(buffer.getvalue().encode("utf-8")) async def wait(self): """EventSourceResponse object is used for streaming data to the client, @@ -119,7 +120,7 @@ async def wait(self): be closed or other task explicitly call ``stop_streaming`` method. """ if self._ping_task is None: - raise RuntimeError('Response is not started') + raise RuntimeError("Response is not started") with contextlib.suppress(asyncio.CancelledError): await self._ping_task @@ -128,7 +129,7 @@ def stop_streaming(self): to notify client that server no longer wants to stream anything. 
""" if self._ping_task is None: - raise RuntimeError('Response is not started') + raise RuntimeError("Response is not started") self._ping_task.cancel() def enable_compression(self, force=False): @@ -147,9 +148,9 @@ def ping_interval(self, value): """ if not isinstance(value, int): - raise TypeError('ping interval must be int') + raise TypeError("ping interval must be int") if value < 0: - raise ValueError('ping interval must be greater then 0') + raise ValueError("ping interval must be greater then 0") self._ping_interval = value @@ -159,7 +160,7 @@ async def _ping(self): # as ping message. while True: await asyncio.sleep(self._ping_interval) - await self.write(': ping{0}{0}'.format(self._sep).encode('utf-8')) + await self.write(": ping{0}{0}".format(self._sep).encode("utf-8")) async def __aenter__(self): return self @@ -182,8 +183,9 @@ def sse_response( ): if not issubclass(response_cls, EventSourceResponse): raise TypeError( - 'response_cls must be subclass of ' - 'aiohttp_sse.EventSourceResponse, got {}'.format(response_cls)) + "response_cls must be subclass of " + "aiohttp_sse.EventSourceResponse, got {}".format(response_cls), + ) sse = response_cls(status=status, reason=reason, headers=headers, sep=sep) return _ContextManager(sse._prepare(request)) diff --git a/tests/test_add_resources.py b/tests/test_add_resources.py index 0b0781e..8571018 100644 --- a/tests/test_add_resources.py +++ b/tests/test_add_resources.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
import json @@ -13,120 +14,130 @@ class TestAddResources(MockServerBase): text_response_obj = { - 'status_code': 200, - 'request_id': 'effd2cb1-1a8c-9f18-9a49-a396f673bd40', - 'code': '', - 'message': '', - 'output': { - 'text': 'hello', - 'choices': None, - 'finish_reason': 'stop' + "status_code": 200, + "request_id": "effd2cb1-1a8c-9f18-9a49-a396f673bd40", + "code": "", + "message": "", + "output": { + "text": "hello", + "choices": None, + "finish_reason": "stop", + }, + "usage": { + "input_tokens": 27, + "output_tokens": 110, }, - 'usage': { - 'input_tokens': 27, - 'output_tokens': 110 - } } def test_default_no_resources_request(self, mock_server: MockServer): response_str = json.dumps(TestAddResources.text_response_obj) mock_server.responses.put(response_str) - prompt = 'hello' - response = HttpRequest.call(model=model, - prompt=prompt, - task='text-generation', - function='generation', - max_tokens=1024, - api_protocol='http', - result_format='message', - n=50) + prompt = "hello" + response = HttpRequest.call( + model=model, + prompt=prompt, + task="text-generation", + function="generation", + max_tokens=1024, + api_protocol="http", + result_format="message", + n=50, + ) req = mock_server.requests.get(block=True) - assert req['model'] == model - assert req['parameters']['max_tokens'] == 1024 - assert req['parameters']['result_format'] == 'message' - assert 'resources' not in req + assert req["model"] == model + assert req["parameters"]["max_tokens"] == 1024 + assert req["parameters"]["result_format"] == "message" + assert "resources" not in req assert response.status_code == HTTPStatus.OK - assert response.output['text'] == 'hello' - assert response.output['choices'] is None - assert response.output['finish_reason'] == 'stop' + assert response.output["text"] == "hello" + assert response.output["choices"] is None + assert response.output["finish_reason"] == "stop" def test_default_with_resources_request(self, mock_server: MockServer): response_str = 
json.dumps(TestAddResources.text_response_obj) mock_server.responses.put(response_str) - prompt = 'hello' - response = HttpRequest.call(model=model, - prompt=prompt, - max_tokens=1024, - task='text-generation', - function='generation', - api_protocol='http', - result_format='message', - resources={ - 'id1': 1, - 'id2': '2', - 'id3': { - 'k': 'v' - } - }) + prompt = "hello" + response = HttpRequest.call( + model=model, + prompt=prompt, + max_tokens=1024, + task="text-generation", + function="generation", + api_protocol="http", + result_format="message", + resources={ + "id1": 1, + "id2": "2", + "id3": { + "k": "v", + }, + }, + ) req = mock_server.requests.get(block=True) - assert req['model'] == model - assert req['parameters']['max_tokens'] == 1024 - assert req['parameters']['result_format'] == 'message' - assert req['resources'] == {'id1': 1, 'id2': '2', 'id3': {'k': 'v'}} + assert req["model"] == model + assert req["parameters"]["max_tokens"] == 1024 + assert req["parameters"]["result_format"] == "message" + assert req["resources"] == {"id1": 1, "id2": "2", "id3": {"k": "v"}} assert response.status_code == HTTPStatus.OK - assert response.output['text'] == 'hello' - assert response.output['choices'] is None - assert response.output['finish_reason'] == 'stop' + assert response.output["text"] == "hello" + assert response.output["choices"] is None + assert response.output["finish_reason"] == "stop" - def test_default_websocket_no_resources_request(self, - mock_server: MockServer): + def test_default_websocket_no_resources_request( + self, + mock_server: MockServer, + ): response_str = json.dumps(TestAddResources.text_response_obj) mock_server.responses.put(response_str) - prompt = 'hello' - HttpRequest.call(model=model, - prompt=prompt, - task='text-generation', - function='generation', - max_tokens=1024, - ws_stream_mode='none', - api_protocol='websocket', - result_format='message', - n=50) + prompt = "hello" + HttpRequest.call( + model=model, + prompt=prompt, + 
task="text-generation", + function="generation", + max_tokens=1024, + ws_stream_mode="none", + api_protocol="websocket", + result_format="message", + n=50, + ) req = mock_server.requests.get(block=True) - assert req['payload']['model'] == model - assert req['payload']['parameters']['max_tokens'] == 1024 - assert req['payload']['parameters']['result_format'] == 'message' - assert 'resources' not in req['payload'] + assert req["payload"]["model"] == model + assert req["payload"]["parameters"]["max_tokens"] == 1024 + assert req["payload"]["parameters"]["result_format"] == "message" + assert "resources" not in req["payload"] def test_websocket_with_resources_request(self, mock_server: MockServer): response_str = json.dumps(TestAddResources.text_response_obj) mock_server.responses.put(response_str) - prompt = 'hello' - HttpRequest.call(model=model, - prompt=prompt, - task='text-generation', - function='generation', - max_tokens=1024, - ws_stream_mode='none', - api_protocol='websocket', - result_format='message', - n=50, - resources={ - 'id1': 1, - 'id2': '2', - 'id3': { - 'k': 'v' - } - }) + prompt = "hello" + HttpRequest.call( + model=model, + prompt=prompt, + task="text-generation", + function="generation", + max_tokens=1024, + ws_stream_mode="none", + api_protocol="websocket", + result_format="message", + n=50, + resources={ + "id1": 1, + "id2": "2", + "id3": { + "k": "v", + }, + }, + ) req = mock_server.requests.get(block=True) - assert req['payload']['model'] == model - assert req['payload']['parameters']['max_tokens'] == 1024 - assert req['payload']['parameters']['result_format'] == 'message' - assert req['payload']['resources'] == { - 'id1': 1, - 'id2': '2', - 'id3': { - 'k': 'v' - } + assert req["payload"]["model"] == model + assert req["payload"]["parameters"]["max_tokens"] == 1024 + assert req["payload"]["parameters"]["result_format"] == "message" + assert req["payload"]["resources"] == { + "id1": 1, + "id2": "2", + "id3": { + "k": "v", + }, } diff --git 
a/tests/test_application.py b/tests/test_application.py index f27d9c4..10e9b46 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. """ @File : test_completion.py @@ -17,51 +18,38 @@ class TestCompletion(MockServerBase): def test_rag_call(self, mock_server: MockServer): test_response = { - 'status_code': 200, - 'request_id': str(uuid.uuid4()), - 'output': { - 'text': - 'API接口说明中,通过parameters的topP属性设置,取值范围在(0,1.0)。', - 'finish_reason': - 'stop', - 'session_id': - str(uuid.uuid4()), - 'doc_references': [{ - 'index_id': - '1', - 'doc_id': - '1234', - 'doc_name': - 'API接口说明.pdf', - 'doc_url': - 'https://127.0.0.1/dl/API接口说明.pdf', - 'title': - 'API接口说明', - 'text': - 'topP取值范围在(0,1.0),取值越大,生成的随机性越高', - 'biz_id': - '2345', - 'images': [ - 'http://127.0.0.1:8080/qqq.png', - 'http://127.0.0.1:8080/www.png' - ] - }], - 'thoughts': [{ - 'thought': - '开启了文档增强,优先检索文档内容', - 'action_type': - 'api', - 'action_name': - '文档检索', - 'action': - 'searchDocument', - 'action_input_stream': - '{"query":"API接口说明中, TopP参数改如何传递?"}', - 'action_input': { - 'query': 'API接口说明中, TopP参数改如何传递?' 
+ "status_code": 200, + "request_id": str(uuid.uuid4()), + "output": { + "text": "API接口说明中,通过parameters的topP属性设置,取值范围在(0,1.0)。", + "finish_reason": "stop", + "session_id": str(uuid.uuid4()), + "doc_references": [ + { + "index_id": "1", + "doc_id": "1234", + "doc_name": "API接口说明.pdf", + "doc_url": "https://127.0.0.1/dl/API接口说明.pdf", + "title": "API接口说明", + "text": "topP取值范围在(0,1.0),取值越大,生成的随机性越高", + "biz_id": "2345", + "images": [ + "http://127.0.0.1:8080/qqq.png", + "http://127.0.0.1:8080/www.png", + ], }, - 'observation': - '''{"data": [ + ], + "thoughts": [ + { + "thought": "开启了文档增强,优先检索文档内容", + "action_type": "api", + "action_name": "文档检索", + "action": "searchDocument", + "action_input_stream": '{"query":"API接口说明中, TopP参数改如何传递?"}', + "action_input": { + "query": "API接口说明中, TopP参数改如何传递?", + }, + "observation": """{"data": [ { "docId": "1234", "docName": "API接口说明", @@ -74,170 +62,202 @@ def test_rag_call(self, mock_server: MockServer): } ], "status": "SUCCESS" - }''', - 'response': - 'API接口说明中, TopP参数是一个float类型的参数,取值范围为0到1.0,默认为1.0。取值越大,生成的随机性越高。[5]' - }] + }""", + "response": "API接口说明中, TopP参数是一个float类型的参数,取值范围为0到1.0,默认为1.0。取值越大,生成的随机性越高。[5]", + }, + ], + }, + "usage": { + "models": [ + { + "model_id": "123", + "input_tokens": 27, + "output_tokens": 110, + }, + ], }, - 'usage': { - 'models': [{ - 'model_id': '123', - 'input_tokens': 27, - 'output_tokens': 110 - }] - } } mock_server.responses.put(json.dumps(test_response)) resp = Application.call( - app_id='1234', - workspace='ws_1234', - prompt='API接口说明中, TopP参数改如何传递?', + app_id="1234", + workspace="ws_1234", + prompt="API接口说明中, TopP参数改如何传递?", top_p=0.2, temperature=1.0, - doc_tag_codes=['t1234', 't2345'], + doc_tag_codes=["t1234", "t2345"], doc_reference_type=Application.DocReferenceType.simple, - has_thoughts=True) + has_thoughts=True, + ) self.check_result(resp, test_response) def test_flow_call(self, mock_server: MockServer): test_response = { - 'status_code': 200, - 'request_id': str(uuid.uuid4()), - 
'output': { - 'text': - '当月的居民用电量为102千瓦。', - 'finish_reason': - 'stop', - 'thoughts': [{ - 'thought': '开启了插件增强', - 'action_type': 'api', - 'action_name': 'plugin', - 'action': 'api', - 'action_input_stream': - '{"userId": "123", "date": "202402", "city": "hangzhou"}', - 'action_input': { - 'userId': '123', - 'date': '202402', - 'city': 'hangzhou' + "status_code": 200, + "request_id": str(uuid.uuid4()), + "output": { + "text": "当月的居民用电量为102千瓦。", + "finish_reason": "stop", + "thoughts": [ + { + "thought": "开启了插件增强", + "action_type": "api", + "action_name": "plugin", + "action": "api", + "action_input_stream": '{"userId": "123", "date": "202402", "city": "hangzhou"}', + "action_input": { + "userId": "123", + "date": "202402", + "city": "hangzhou", + }, + "observation": """{"quantity": 102, "type": "resident", "date": "202402", "unit": "千瓦"}""", + "response": "当月的居民用电量为102千瓦。", }, - 'observation': - '''{"quantity": 102, "type": "resident", "date": "202402", "unit": "千瓦"}''', - 'response': '当月的居民用电量为102千瓦。' - }] + ], + }, + "usage": { + "models": [ + { + "model_id": "123", + "input_tokens": 50, + "output_tokens": 33, + }, + ], }, - 'usage': { - 'models': [{ - 'model_id': '123', - 'input_tokens': 50, - 'output_tokens': 33 - }] - } } mock_server.responses.put(json.dumps(test_response)) - biz_params = {'userId': '123'} + biz_params = {"userId": "123"} - resp = Application.call(app_id='1234', - prompt='本月的用电量是多少?', - workspace='ws_1234', - top_p=0.2, - biz_params=biz_params, - has_thoughts=True) + resp = Application.call( + app_id="1234", + prompt="本月的用电量是多少?", + workspace="ws_1234", + top_p=0.2, + biz_params=biz_params, + has_thoughts=True, + ) self.check_result(resp, test_response) def test_call_with_error(self, mock_server: MockServer): test_response = { - 'status_code': 400, - 'request_id': str(uuid.uuid4()), - 'code': 'InvalidAppId', - 'message': 'App id is invalid' + "status_code": 400, + "request_id": str(uuid.uuid4()), + "code": "InvalidAppId", + "message": "App id 
is invalid", } mock_server.responses.put(json.dumps(test_response)) resp = Application.call( - app_id='1234', - workspace='ws_1234', - prompt='API接口说明中, TopP参数改如何传递?', + app_id="1234", + workspace="ws_1234", + prompt="API接口说明中, TopP参数改如何传递?", top_p=0.2, temperature=1.0, doc_reference_type=Application.DocReferenceType.simple, - has_thoughts=True) + has_thoughts=True, + ) - assert resp.status_code == test_response.get('status_code') - assert resp.request_id == test_response.get('request_id') - assert resp.code == test_response.get('code') - assert resp.message == test_response.get('message') + assert resp.status_code == test_response.get("status_code") + assert resp.request_id == test_response.get("request_id") + assert resp.code == test_response.get("code") + assert resp.message == test_response.get("message") @staticmethod def check_result(resp: ApplicationResponse, test_response: Dict): assert resp.status_code == 200 - assert resp.request_id == test_response.get('request_id') + assert resp.request_id == test_response.get("request_id") # output assert resp.output is not None - assert resp.output.text == test_response.get('output', {}).get('text') + assert resp.output.text == test_response.get("output", {}).get("text") assert resp.output.finish_reason == test_response.get( - 'output', {}).get('finish_reason') + "output", + {}, + ).get("finish_reason") assert resp.output.session_id == test_response.get( - 'output', {}).get('session_id') + "output", + {}, + ).get("session_id") # usage assert resp.usage.models is not None and len(resp.usage.models) > 0 model_usage = resp.usage.models[0] - expected_model_usage = test_response.get('usage', - {}).get('models', [])[0] - assert model_usage.model_id == expected_model_usage.get('model_id') + expected_model_usage = test_response.get( + "usage", + {}, + ).get( + "models", + [], + )[0] + assert model_usage.model_id == expected_model_usage.get("model_id") assert model_usage.input_tokens == expected_model_usage.get( - 
'input_tokens') + "input_tokens", + ) assert model_usage.output_tokens == expected_model_usage.get( - 'output_tokens') + "output_tokens", + ) # doc reference - expected_doc_refs = test_response.get('output', - {}).get('doc_references') + expected_doc_refs = test_response.get( + "output", + {}, + ).get("doc_references") if expected_doc_refs is not None and len(expected_doc_refs) > 0: doc_refs = resp.output.doc_references assert doc_refs is not None and len(doc_refs) == len( - expected_doc_refs) + expected_doc_refs, + ) for i in range(len(doc_refs)): assert doc_refs[i].index_id == expected_doc_refs[i].get( - 'index_id') - assert doc_refs[i].doc_id == expected_doc_refs[i].get('doc_id') + "index_id", + ) + assert doc_refs[i].doc_id == expected_doc_refs[i].get("doc_id") assert doc_refs[i].doc_name == expected_doc_refs[i].get( - 'doc_name') + "doc_name", + ) assert doc_refs[i].doc_url == expected_doc_refs[i].get( - 'doc_url') - assert doc_refs[i].title == expected_doc_refs[i].get('title') - assert doc_refs[i].text == expected_doc_refs[i].get('text') - assert doc_refs[i].biz_id == expected_doc_refs[i].get('biz_id') + "doc_url", + ) + assert doc_refs[i].title == expected_doc_refs[i].get("title") + assert doc_refs[i].text == expected_doc_refs[i].get("text") + assert doc_refs[i].biz_id == expected_doc_refs[i].get("biz_id") assert json.dumps(doc_refs[i].images) == json.dumps( - expected_doc_refs[i].get('images')) + expected_doc_refs[i].get("images"), + ) # thoughts - expected_thoughts = test_response.get('output', {}).get('thoughts') + expected_thoughts = test_response.get("output", {}).get("thoughts") if expected_thoughts is not None and len(expected_thoughts) > 0: thoughts = resp.output.thoughts assert thoughts is not None and len(thoughts) == len( - expected_thoughts) + expected_thoughts, + ) for i in range(len(thoughts)): assert thoughts[i].thought == expected_thoughts[i].get( - 'thought') - assert thoughts[i].action == expected_thoughts[i].get('action') + "thought", + ) 
+ assert thoughts[i].action == expected_thoughts[i].get("action") assert thoughts[i].action_name == expected_thoughts[i].get( - 'action_name') + "action_name", + ) assert thoughts[i].action_type == expected_thoughts[i].get( - 'action_type') + "action_type", + ) assert json.dumps(thoughts[i].action_input) == json.dumps( - expected_thoughts[i].get('action_input')) + expected_thoughts[i].get("action_input"), + ) assert thoughts[i].action_input_stream == expected_thoughts[ - i].get('action_input_stream') + i + ].get("action_input_stream") assert thoughts[i].observation == expected_thoughts[i].get( - 'observation') + "observation", + ) assert thoughts[i].response == expected_thoughts[i].get( - 'response') + "response", + ) diff --git a/tests/test_asr_phrases.py b/tests/test_asr_phrases.py index 5efa2aa..4055e05 100644 --- a/tests/test_asr_phrases.py +++ b/tests/test_asr_phrases.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. import argparse @@ -13,7 +14,7 @@ from tests.constants import TEST_JOB_ID from tests.mock_request_base import MockRequestBase -logger = logging.getLogger('dashscope') +logger = logging.getLogger("dashscope") logger.setLevel(logging.DEBUG) # create console handler and set level to debug console_handler = logging.StreamHandler() @@ -21,7 +22,8 @@ # create formatter formatter = logging.Formatter( - '%(asctime)s - %(name)s - %(levelname)s - %(message)s') + "%(asctime)s - %(name)s - %(levelname)s - %(message)s", +) # add formatter to ch console_handler.setFormatter(formatter) @@ -33,115 +35,133 @@ class TestAsrPhrases(MockRequestBase): @classmethod def setup_class(cls): super().setup_class() - cls.model = 'asr' - cls.phrase = {'黄鸡': 5} - cls.update_phrase = {'黄鸡': 2, '红鸡': 1} + cls.model = "asr" + cls.phrase = {"黄鸡": 5} + cls.update_phrase = {"黄鸡": 2, "红鸡": 1} cls.phrase_id = TEST_JOB_ID def test_create_phrases(self, http_server): result = AsrPhraseManager.create_phrases( model=self.model, 
phrases=self.phrase, - headers={'X-Request-Id': 'empty_file_ids'}) + headers={"X-Request-Id": "empty_file_ids"}, + ) assert result is not None assert result.status_code == HTTPStatus.OK assert result.output is not None - assert result.output['finetuned_output'] is not None - assert len(result.output['finetuned_output']) > 0 - self.phrase_id = result.output['finetuned_output'] + assert result.output["finetuned_output"] is not None + assert len(result.output["finetuned_output"]) > 0 + self.phrase_id = result.output["finetuned_output"] def test_update_phrases(self, http_server): result = AsrPhraseManager.update_phrases( model=self.model, phrase_id=self.phrase_id, phrases=self.update_phrase, - headers={'X-Request-Id': 'empty_file_ids'}) + headers={"X-Request-Id": "empty_file_ids"}, + ) assert result is not None assert result.status_code == HTTPStatus.OK - assert result.output['finetuned_output'] is not None - assert len(result.output['finetuned_output']) > 0 + assert result.output["finetuned_output"] is not None + assert len(result.output["finetuned_output"]) > 0 def test_query_phrases(self, http_server): result = AsrPhraseManager.query_phrases(phrase_id=self.phrase_id) assert result is not None assert result.status_code == HTTPStatus.OK - assert result.output['finetuned_output'] is not None - assert len(result.output['finetuned_output']) > 0 - assert result.output['model'] is not None - assert len(result.output['model']) > 0 + assert result.output["finetuned_output"] is not None + assert len(result.output["finetuned_output"]) > 0 + assert result.output["model"] is not None + assert len(result.output["model"]) > 0 def test_list_phrases(self, http_server): result = AsrPhraseManager.list_phrases(page=1, page_size=10) assert result is not None assert result.status_code == HTTPStatus.OK - assert result.output['finetuned_outputs'] is not None - assert len(result.output['finetuned_outputs']) > 0 + assert result.output["finetuned_outputs"] is not None + assert 
len(result.output["finetuned_outputs"]) > 0 def test_delete_phrases(self, http_server): result = AsrPhraseManager.delete_phrases(phrase_id=self.phrase_id) assert result is not None assert result.status_code == HTTPStatus.OK - assert result.output['finetuned_output'] is not None - assert len(result.output['finetuned_output']) > 0 + assert result.output["finetuned_output"] is not None + assert len(result.output["finetuned_output"]) > 0 def str2bool(str): - return True if str.lower() == 'true' else False + return True if str.lower() == "true" else False def complete_url(url: str) -> str: parsed = urlparse(url) - base_url = ''.join([parsed.scheme, '://', parsed.netloc]) - dashscope.base_websocket_api_url = '/'.join( - [base_url, 'api-ws', dashscope.common.env.api_version, 'inference']) - dashscope.base_http_api_url = url = '/'.join( - [base_url, 'api', dashscope.common.env.api_version]) - print('Set base_websocket_api_url: ', dashscope.base_websocket_api_url) - print('Set base_http_api_url: ', dashscope.base_http_api_url) - - -def phrases(model, phrase_id: str, phrases: dict, page: int, page_size: int, - delete: bool): - print('phrase_id: ', phrase_id) - print('phrase: ', phrases) - print('delete flag: ', delete) + base_url = "".join([parsed.scheme, "://", parsed.netloc]) + dashscope.base_websocket_api_url = "/".join( + [base_url, "api-ws", dashscope.common.env.api_version, "inference"], + ) + dashscope.base_http_api_url = url = "/".join( + [base_url, "api", dashscope.common.env.api_version], + ) + print("Set base_websocket_api_url: ", dashscope.base_websocket_api_url) + print("Set base_http_api_url: ", dashscope.base_http_api_url) + + +def phrases( + model, + phrase_id: str, + phrases: dict, + page: int, + page_size: int, + delete: bool, +): + print("phrase_id: ", phrase_id) + print("phrase: ", phrases) + print("delete flag: ", delete) if len(phrases) > 0: if phrase_id is not None: - print('Update phrases -->') - return AsrPhraseManager.update_phrases(model=model, - 
phrase_id=phrase_id, - phrases=phrases) + print("Update phrases -->") + return AsrPhraseManager.update_phrases( + model=model, + phrase_id=phrase_id, + phrases=phrases, + ) else: - print('Create phrases -->') - return AsrPhraseManager.create_phrases(model=model, - phrases=phrases) + print("Create phrases -->") + return AsrPhraseManager.create_phrases( + model=model, + phrases=phrases, + ) else: if delete: - print('Delete phrases -->') + print("Delete phrases -->") return AsrPhraseManager.delete_phrases(phrase_id=phrase_id) else: if phrase_id is not None: - print('Query phrases -->') + print("Query phrases -->") return AsrPhraseManager.query_phrases(phrase_id=phrase_id) if page is not None and page_size is not None: - print('List phrases page %d page_size %d -->' % - (page, page_size)) - return AsrPhraseManager.list_phrases(page=page, - page_size=page_size) + print( + "List phrases page %d page_size %d -->" + % (page, page_size), + ) + return AsrPhraseManager.list_phrases( + page=page, + page_size=page_size, + ) @pytest.mark.skip def test_by_user(): parser = argparse.ArgumentParser() - parser.add_argument('--model', type=str, default='paraformer-realtime-v1') - parser.add_argument('--phrase', type=str, default='') - parser.add_argument('--phrase_id', type=str, default=None) - parser.add_argument('--delete', type=str2bool, default='False') - parser.add_argument('--page', type=int, default=None) - parser.add_argument('--page_size', type=int, default=None) - parser.add_argument('--api_key', type=str) - parser.add_argument('--base_url', type=str) + parser.add_argument("--model", type=str, default="paraformer-realtime-v1") + parser.add_argument("--phrase", type=str, default="") + parser.add_argument("--phrase_id", type=str, default=None) + parser.add_argument("--delete", type=str2bool, default="False") + parser.add_argument("--page", type=int, default=None) + parser.add_argument("--page_size", type=int, default=None) + parser.add_argument("--api_key", type=str) + 
parser.add_argument("--base_url", type=str) args = parser.parse_args() if args.api_key is not None: @@ -152,46 +172,60 @@ def test_by_user(): phrase_dict = {} if len(args.phrase) > 0: phrase_dict = json.loads(args.phrase) - resp = phrases(model=args.model, - phrase_id=args.phrase_id, - phrases=phrase_dict, - page=args.page, - page_size=args.page_size, - delete=args.delete) + resp = phrases( + model=args.model, + phrase_id=args.phrase_id, + phrases=phrase_dict, + page=args.page, + page_size=args.page_size, + delete=args.delete, + ) if resp.status_code == HTTPStatus.OK: - print('Response of phrases: ', resp) + print("Response of phrases: ", resp) if resp is not None and resp.output is not None: output = resp.output - print('\nGet output: %s\n' % (str(output))) - - if 'finetuned_output' in output and output[ - 'finetuned_output'] is not None: - print('Get phrase_id: %s' % (output['finetuned_output'])) - if 'job_id' in output and output['job_id'] is not None: - print('Get job_id: %s' % (output['job_id'])) - if 'create_time' in output and output['create_time'] is not None: - print('Get create_time: %s' % (output['create_time'])) - if 'model' in output and output['model'] is not None: - print('Get model_id: %s' % (output['model'])) - if 'output_type' in output and output['output_type'] is not None: - print('Get output_type: %s' % (output['output_type'])) - - if 'finetuned_outputs' in output and output[ - 'finetuned_outputs'] is not None: - outputs = output['finetuned_outputs'] - print('Get %d info from page_no:%d page_size:%d total:%d ->' % - (len(outputs), output['page_no'], output['page_size'], - output['total'])) + print("\nGet output: %s\n" % (str(output))) + + if ( + "finetuned_output" in output + and output["finetuned_output"] is not None + ): + print("Get phrase_id: %s" % (output["finetuned_output"])) + if "job_id" in output and output["job_id"] is not None: + print("Get job_id: %s" % (output["job_id"])) + if "create_time" in output and output["create_time"] is 
not None: + print("Get create_time: %s" % (output["create_time"])) + if "model" in output and output["model"] is not None: + print("Get model_id: %s" % (output["model"])) + if "output_type" in output and output["output_type"] is not None: + print("Get output_type: %s" % (output["output_type"])) + + if ( + "finetuned_outputs" in output + and output["finetuned_outputs"] is not None + ): + outputs = output["finetuned_outputs"] + print( + "Get %d info from page_no:%d page_size:%d total:%d ->" + % ( + len(outputs), + output["page_no"], + output["page_size"], + output["total"], + ), + ) for item in outputs: - print(' get phrase_id: %s' % (item['finetuned_output'])) - print(' get job_id: %s' % (item['job_id'])) - print(' get create_time: %s' % (item['create_time'])) - print(' get model_id: %s' % (item['model'])) - print(' get output_type: %s\n' % (item['output_type'])) + print(" get phrase_id: %s" % (item["finetuned_output"])) + print(" get job_id: %s" % (item["job_id"])) + print(" get create_time: %s" % (item["create_time"])) + print(" get model_id: %s" % (item["model"])) + print(" get output_type: %s\n" % (item["output_type"])) else: - print('ERROR, status_code:%d, code_message:%s, error_message:%s' % - (resp.status_code, resp.code, resp.message)) + print( + "ERROR, status_code:%d, code_message:%s, error_message:%s" + % (resp.status_code, resp.code, resp.message), + ) -if __name__ == '__main__': +if __name__ == "__main__": test_by_user() diff --git a/tests/test_assistant_files.py b/tests/test_assistant_files.py index 572e5d2..113f63e 100644 --- a/tests/test_assistant_files.py +++ b/tests/test_assistant_files.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
import json @@ -13,16 +14,16 @@ def test_create(self, mock_server: MockServer): file_id = str(uuid.uuid4()) assistant_id = str(uuid.uuid4()) response_body = { - 'id': file_id, - 'object': 'assistant.file', - 'created_at': 111111, - 'assistant_id': assistant_id + "id": file_id, + "object": "assistant.file", + "created_at": 111111, + "assistant_id": assistant_id, } mock_server.responses.put(json.dumps(response_body)) response = Files.create(assistant_id, file_id=file_id) req = mock_server.requests.get(block=True) - assert req['body']['file_id'] == file_id - assert req['path'] == f'/api/v1/assistants/{assistant_id}/files' + assert req["body"]["file_id"] == file_id + assert req["path"] == f"/api/v1/assistants/{assistant_id}/files" assert response.id == file_id assert response.assistant_id == assistant_id @@ -30,16 +31,17 @@ def test_retrieve(self, mock_server: MockServer): file_id = str(uuid.uuid4()) assistant_id = str(uuid.uuid4()) response_body = { - 'id': file_id, - 'object': 'assistant.file', - 'created_at': 111111, - 'assistant_id': assistant_id + "id": file_id, + "object": "assistant.file", + "created_at": 111111, + "assistant_id": assistant_id, } mock_server.responses.put(json.dumps(response_body)) response = Files.retrieve(file_id, assistant_id=assistant_id) req = mock_server.requests.get(block=True) - assert req[ - 'path'] == f'/api/v1/assistants/{assistant_id}/files/{file_id}' + assert ( + req["path"] == f"/api/v1/assistants/{assistant_id}/files/{file_id}" + ) assert response.id == file_id assert response.assistant_id == assistant_id @@ -48,29 +50,31 @@ def test_list(self, mock_server: MockServer): file_id_2 = str(uuid.uuid4()) assistant_id = str(uuid.uuid4()) response_body = { - 'first_id': - file_id_1, - 'last_id': - file_id_2, - 'has_more': - False, - 'data': [{ - 'id': file_id_1, - 'object': 'assistant.file', - 'created_at': 111111, - 'assistant_id': assistant_id - }, { - 'id': file_id_2, - 'object': 'assistant.file', - 'created_at': 111111, - 
'assistant_id': assistant_id - }] + "first_id": file_id_1, + "last_id": file_id_2, + "has_more": False, + "data": [ + { + "id": file_id_1, + "object": "assistant.file", + "created_at": 111111, + "assistant_id": assistant_id, + }, + { + "id": file_id_2, + "object": "assistant.file", + "created_at": 111111, + "assistant_id": assistant_id, + }, + ], } mock_server.responses.put(json.dumps(response_body)) - response = Files.list(assistant_id, limit=10, order='asc') + response = Files.list(assistant_id, limit=10, order="asc") req = mock_server.requests.get(block=True) - assert req[ - 'path'] == f'/api/v1/assistants/{assistant_id}/files?limit=10&order=asc' + assert ( + req["path"] + == f"/api/v1/assistants/{assistant_id}/files?limit=10&order=asc" + ) assert response.first_id == file_id_1 assert response.last_id == file_id_2 assert response.data[0].assistant_id == assistant_id @@ -79,15 +83,16 @@ def test_delete(self, mock_server: MockServer): file_id = str(uuid.uuid4()) assistant_id = str(uuid.uuid4()) response_body = { - 'id': file_id, - 'object': 'assistant.file', - 'created_at': 111111, - 'deleted': True + "id": file_id, + "object": "assistant.file", + "created_at": 111111, + "deleted": True, } mock_server.responses.put(json.dumps(response_body)) response = Files.delete(file_id, assistant_id=assistant_id) req = mock_server.requests.get(block=True) - assert req[ - 'path'] == f'/api/v1/assistants/{assistant_id}/files/{file_id}' + assert ( + req["path"] == f"/api/v1/assistants/{assistant_id}/files/{file_id}" + ) assert response.id == file_id assert response.deleted is True diff --git a/tests/test_assistants.py b/tests/test_assistants.py index 6759162..4eea86d 100644 --- a/tests/test_assistants.py +++ b/tests/test_assistants.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
import json @@ -9,120 +10,133 @@ class TestAssistants(MockServerBase): - TEST_MODEL_NAME = 'test_model' - ASSISTANT_ID = 'asst_42bff274-6d44-45b8-90b1-11dd14534499' + TEST_MODEL_NAME = "test_model" + ASSISTANT_ID = "asst_42bff274-6d44-45b8-90b1-11dd14534499" case_data = None @classmethod def setup_class(cls): cls.case_data = json.load( - open('tests/data/assistant.json', 'r', encoding='utf-8')) + open("tests/data/assistant.json", "r", encoding="utf-8"), + ) super().setup_class() def test_create_assistant_only_model(self, mock_server: MockServer): response_obj = { - 'id': 'asst_42bff274-6d44-45b8-90b1-11dd14534499', - 'object': 'assistant', - 'created_at': 1709633914432, - 'model': self.TEST_MODEL_NAME, - 'name': '', - 'description': '', - 'instructions': '', - 'tools': [], - 'file_ids': [], - 'metadata': {}, - 'account_id': 'id0', - 'gmt_crete': '2024-03-05 18:18:34', - 'gmt_update': '2024-03-05 18:18:34', - 'is_deleted': False, - 'request_id': 'e547a1ea-ddc9-9ced-a620-23cf36e57359' + "id": "asst_42bff274-6d44-45b8-90b1-11dd14534499", + "object": "assistant", + "created_at": 1709633914432, + "model": self.TEST_MODEL_NAME, + "name": "", + "description": "", + "instructions": "", + "tools": [], + "file_ids": [], + "metadata": {}, + "account_id": "id0", + "gmt_crete": "2024-03-05 18:18:34", + "gmt_update": "2024-03-05 18:18:34", + "is_deleted": False, + "request_id": "e547a1ea-ddc9-9ced-a620-23cf36e57359", } response_str = json.dumps(response_obj) mock_server.responses.put(response_str) response = Assistants.create(model=self.TEST_MODEL_NAME) req = mock_server.requests.get(block=True) - assert response.model == req['model'] - assert req['model'] == self.TEST_MODEL_NAME + assert response.model == req["model"] + assert req["model"] == self.TEST_MODEL_NAME def test_create_assistant(self, mock_server: MockServer): response_obj = { - 'id': 'asst', - 'object': 'assistant', - 'created_at': 1709634513039, - 'model': self.TEST_MODEL_NAME, - 'name': 'hello', - 'description': 
'desc', - 'instructions': 'Your a helpful assistant.', - 'tools': [{ - 'type': 'search' - }, { - 'type': 'wanx' - }], - 'file_ids': [], - 'account_id': 'xxxx', - 'gmt_crete': '2024-03-05 18:28:33', - 'gmt_update': '2024-03-05 18:28:33', - 'is_deleted': False, - 'metadata': { - 'key': 'value' + "id": "asst", + "object": "assistant", + "created_at": 1709634513039, + "model": self.TEST_MODEL_NAME, + "name": "hello", + "description": "desc", + "instructions": "Your a helpful assistant.", + "tools": [ + { + "type": "search", + }, + { + "type": "wanx", + }, + ], + "file_ids": [], + "account_id": "xxxx", + "gmt_crete": "2024-03-05 18:28:33", + "gmt_update": "2024-03-05 18:28:33", + "is_deleted": False, + "metadata": { + "key": "value", }, - 'request_id': 'request_id' + "request_id": "request_id", } response_str = json.dumps(response_obj) mock_server.responses.put(response_str) - response = Assistants.create(model=self.TEST_MODEL_NAME, - name='hello', - description='desc', - instructions='Your a helpful assistant.', - tools=[{ - 'type': 'search' - }, { - 'type': 'wanx' - }], - metadata={'key': 'value'}) + response = Assistants.create( + model=self.TEST_MODEL_NAME, + name="hello", + description="desc", + instructions="Your a helpful assistant.", + tools=[ + { + "type": "search", + }, + { + "type": "wanx", + }, + ], + metadata={"key": "value"}, + ) req = mock_server.requests.get(block=True) - assert response.model == req['model'] - assert req['model'] == self.TEST_MODEL_NAME - assert req['tools'] == [{'type': 'search'}, {'type': 'wanx'}] - assert req['instructions'] == 'Your a helpful assistant.' - assert req['name'] == 'hello' + assert response.model == req["model"] + assert req["model"] == self.TEST_MODEL_NAME + assert req["tools"] == [{"type": "search"}, {"type": "wanx"}] + assert req["instructions"] == "Your a helpful assistant." 
+ assert req["name"] == "hello" assert response.file_ids == [] - assert response.instructions == req['instructions'] - assert response.metadata == req['metadata'] + assert response.instructions == req["instructions"] + assert response.metadata == req["metadata"] def test_create_assistant_function_call(self, mock_server: MockServer): - request_body = self.case_data['test_function_call_request'] + request_body = self.case_data["test_function_call_request"] response_body = json.dumps( - self.case_data['test_function_call_response']) + self.case_data["test_function_call_response"], + ) mock_server.responses.put(response_body) response = Assistants.create(**request_body) req = mock_server.requests.get(block=True) - assert response.model == req['model'] - assert response.tools[2].function.name == 'big_add' + assert response.model == req["model"] + assert response.tools[2].function.name == "big_add" assert response.file_ids == [] - assert response.instructions == req['instructions'] + assert response.instructions == req["instructions"] def test_retrieve_assistant(self, mock_server: MockServer): response_obj = { - 'id': self.ASSISTANT_ID, - 'object': 'assistant', - 'created_at': 1709635413785, - 'model': self.TEST_MODEL_NAME, - 'name': 'hello', - 'description': 'desc', - 'instructions': 'Your a helpful assistant.', - 'tools': [{ - 'type': 'search' - }, { - 'type': 'wanx' - }], - 'file_ids': [], - 'metadata': {}, - 'account_id': 'sk-xxx', - 'gmt_crete': '2024-03-05 18:43:33', - 'gmt_update': '2024-03-05 18:43:33', - 'is_deleted': False, - 'request_id': 'dc2c8195-14df-997a-9d03-ee14887b7e1d' + "id": self.ASSISTANT_ID, + "object": "assistant", + "created_at": 1709635413785, + "model": self.TEST_MODEL_NAME, + "name": "hello", + "description": "desc", + "instructions": "Your a helpful assistant.", + "tools": [ + { + "type": "search", + }, + { + "type": "wanx", + }, + ], + "file_ids": [], + "metadata": {}, + "account_id": "sk-xxx", + "gmt_crete": "2024-03-05 18:43:33", + 
"gmt_update": "2024-03-05 18:43:33", + "is_deleted": False, + "request_id": "dc2c8195-14df-997a-9d03-ee14887b7e1d", } response_str = json.dumps(response_obj) mock_server.responses.put(response_str) @@ -132,63 +146,70 @@ def test_retrieve_assistant(self, mock_server: MockServer): assert response.model == self.TEST_MODEL_NAME assert req_assistant_id == self.ASSISTANT_ID assert response.file_ids == [] - assert response.instructions == response_obj['instructions'] - assert response.metadata == response_obj['metadata'] + assert response.instructions == response_obj["instructions"] + assert response.metadata == response_obj["metadata"] def test_list_assistant(self, mock_server: MockServer): - response_obj = self.case_data['test_list'] + response_obj = self.case_data["test_list"] mock_server.responses.put(json.dumps(response_obj)) - response = Assistants.list(limit=10, - order='inc', - after='after', - before='before', - api_key='123') + response = Assistants.list( + limit=10, + order="inc", + after="after", + before="before", + api_key="123", + ) # get assistant id we send. 
req = mock_server.requests.get(block=True) - assert req == '/api/v1/assistants?limit=10&order=inc&after=after&before=before' + assert ( + req + == "/api/v1/assistants?limit=10&order=inc&after=after&before=before" + ) assert len(response.data) == 2 - assert response.data[0].id == 'asst_1' - assert response.data[1].id == 'asst_2' + assert response.data[0].id == "asst_1" + assert response.data[1].id == "asst_2" def test_update_assistant(self, mock_server: MockServer): updated_desc = str(uuid.uuid4()) response_obj = { - 'id': self.ASSISTANT_ID, - 'model': self.TEST_MODEL_NAME, - 'name': 'hello', - 'created_at': 1709635413785, - 'description': updated_desc, - 'file_ids': [], - 'instructions': 'Your a helpful assistant.', - 'metadata': {}, - 'tools': [], - 'object': 'assistant', - 'account_id': 'ff', - 'gmt_crete': '2024-03-05 18:43:33', - 'gmt_update': '2024-03-06 16:12:52', - 'is_deleted': False, - 'request_id': '00300fca-2b54-9cc6-8973-5c88df51d194' + "id": self.ASSISTANT_ID, + "model": self.TEST_MODEL_NAME, + "name": "hello", + "created_at": 1709635413785, + "description": updated_desc, + "file_ids": [], + "instructions": "Your a helpful assistant.", + "metadata": {}, + "tools": [], + "object": "assistant", + "account_id": "ff", + "gmt_crete": "2024-03-05 18:43:33", + "gmt_update": "2024-03-06 16:12:52", + "is_deleted": False, + "request_id": "00300fca-2b54-9cc6-8973-5c88df51d194", } response_str = json.dumps(response_obj) mock_server.responses.put(response_str) - response = Assistants.update(self.ASSISTANT_ID, - description=updated_desc) + response = Assistants.update( + self.ASSISTANT_ID, + description=updated_desc, + ) # get assistant id we send. 
req = mock_server.requests.get(block=True) assert req is not None assert response.model == self.TEST_MODEL_NAME assert response.id == self.ASSISTANT_ID assert response.file_ids == [] - assert response.instructions == response_obj['instructions'] - assert response.tools == response_obj['tools'] - assert response.metadata == response_obj['metadata'] + assert response.instructions == response_obj["instructions"] + assert response.tools == response_obj["tools"] + assert response.metadata == response_obj["metadata"] assert response.description == updated_desc def test_delete_assistant(self, mock_server: MockServer): response_obj = { - 'id': self.ASSISTANT_ID, - 'object': 'assistant.deleted', - 'deleted': True + "id": self.ASSISTANT_ID, + "object": "assistant.deleted", + "deleted": True, } response_str = json.dumps(response_obj) mock_server.responses.put(response_str) @@ -197,5 +218,5 @@ def test_delete_assistant(self, mock_server: MockServer): assert req == self.ASSISTANT_ID assert response.id == self.ASSISTANT_ID - assert response.object == 'assistant.deleted' + assert response.object == "assistant.deleted" assert response.deleted is True diff --git a/tests/test_async_api.py b/tests/test_async_api.py index b99d8a6..f692e0d 100644 --- a/tests/test_async_api.py +++ b/tests/test_async_api.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. from http import HTTPStatus @@ -10,14 +11,15 @@ class AsyncRequest(BaseAsyncApi): - """API for File Transcriber models. + """API for File Transcriber models.""" - """ @classmethod - def fetch(cls, - task: Union[str, DashScopeAPIResponse], - api_key: str = None, - workspace: str = None) -> DashScopeAPIResponse: + def fetch( + cls, + task: Union[str, DashScopeAPIResponse], + api_key: str = None, + workspace: str = None, + ) -> DashScopeAPIResponse: """Query the task status. 
Args: @@ -31,10 +33,12 @@ def fetch(cls, return super().fetch(task, api_key=api_key, workspace=workspace) @classmethod - def wait(cls, - task: Union[str, DashScopeAPIResponse], - api_key: str = None, - workspace: str = None) -> DashScopeAPIResponse: + def wait( + cls, + task: Union[str, DashScopeAPIResponse], + api_key: str = None, + workspace: str = None, + ) -> DashScopeAPIResponse: """Wait for the task to complete and return the result. Args: @@ -46,12 +50,14 @@ def wait(cls, return super().wait(task, api_key=api_key, workspace=workspace) @classmethod - def call(cls, - model: str, - url: str, - api_key: str = None, - workspace: str = None, - **kwargs) -> DashScopeAPIResponse: + def call( + cls, + model: str, + url: str, + api_key: str = None, + workspace: str = None, + **kwargs, + ) -> DashScopeAPIResponse: """Call the async interface and return the result Args: @@ -64,19 +70,23 @@ def call(cls, Raises: InputRequired: The file cannot be empty. """ - return super().call(model, - url, - api_key=api_key, - workspace=workspace, - **kwargs) + return super().call( + model, + url, + api_key=api_key, + workspace=workspace, + **kwargs, + ) @classmethod - def async_call(cls, - model: str, - url: str, - api_key: str = None, - workspace: str = None, - **kwargs) -> DashScopeAPIResponse: + def async_call( + cls, + model: str, + url: str, + api_key: str = None, + workspace: str = None, + **kwargs, + ) -> DashScopeAPIResponse: """Call the async interface and return task information Args: @@ -90,16 +100,18 @@ def async_call(cls, InputRequired: The file cannot be empty. 
""" - response = super().async_call(model=model, - task_group='audio', - task='asr', - function='transcription', - input={'file_urls': [url]}, - api_protocol=ApiProtocol.HTTP, - http_method=HTTPMethod.POST, - channel_id=[0], - workspace=workspace, - **kwargs) + response = super().async_call( + model=model, + task_group="audio", + task="asr", + function="transcription", + input={"file_urls": [url]}, + api_protocol=ApiProtocol.HTTP, + http_method=HTTPMethod.POST, + channel_id=[0], + workspace=workspace, + **kwargs, + ) return response @@ -107,54 +119,60 @@ class TestAsyncRequest(BaseTestEnvironment): @classmethod def setup_class(cls): super().setup_class() - cls.model = 'paraformer-8k-v1' + cls.model = "paraformer-8k-v1" def test_start_async_request(self): - url = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example.wav' # noqa: E501 - resp = AsyncRequest.async_call(model='paraformer-8k-1', url=url) + url = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example.wav" # noqa: E501 + resp = AsyncRequest.async_call(model="paraformer-8k-1", url=url) import json + js = json.dumps(resp, ensure_ascii=False) print(js) print(resp.output) if resp.status_code == HTTPStatus.OK: - assert resp.output['task_id'] is not None + assert resp.output["task_id"] is not None else: - print('Failed id: %s code: %s, message: %s' % - (resp.request_id, resp.status_code, resp.message)) + print( + "Failed id: %s code: %s, message: %s" + % (resp.request_id, resp.status_code, resp.message), + ) def test_status_async_request(self): - url = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example.wav' # noqa: E501 + url = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example.wav" # noqa: E501 resp = AsyncRequest.async_call(model=self.model, url=url) assert resp.status_code == HTTPStatus.OK resp = AsyncRequest.fetch(resp) assert resp.status_code == HTTPStatus.OK if resp.status_code == 
HTTPStatus.OK: print(resp.output) - assert resp.output['task_id'] is not None + assert resp.output["task_id"] is not None else: - print('Failed id: %s code: %s, message: %s' % - (resp.request_id, resp.status_code, resp.message)) + print( + "Failed id: %s code: %s, message: %s" + % (resp.request_id, resp.status_code, resp.message), + ) def test_wait_async_request(self): - url = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example.wav' # noqa: E501 + url = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example.wav" # noqa: E501 resp = AsyncRequest.async_call(model=self.model, url=url) assert resp.status_code == HTTPStatus.OK rsp = AsyncRequest.wait(resp) assert rsp.status_code == HTTPStatus.OK print(rsp.output) import json + js = json.dumps(rsp, ensure_ascii=False) print(js) def test_sync_request(self): - url = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example.wav' # noqa: E501 + url = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example.wav" # noqa: E501 resp = AsyncRequest.call(model=self.model, url=url) assert resp.status_code == HTTPStatus.OK print(resp.output) - assert resp.output['task_id'] is not None + assert resp.output["task_id"] is not None def test_wait(self): - resp = AsyncRequest.wait('dfjkdasfjadsfasd') + resp = AsyncRequest.wait("dfjkdasfjadsfasd") assert resp.status_code == HTTPStatus.OK - assert resp.output['task_status'] != TaskStatus.SUCCEEDED + assert resp.output["task_status"] != TaskStatus.SUCCEEDED print(resp) diff --git a/tests/test_code_generation.py b/tests/test_code_generation.py index 4c0359f..798d6c4 100644 --- a/tests/test_code_generation.py +++ b/tests/test_code_generation.py @@ -1,11 +1,14 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
import json from http import HTTPStatus from dashscope import CodeGeneration -from dashscope.aigc.code_generation import (AttachmentRoleMessageParam, - UserRoleMessageParam) +from dashscope.aigc.code_generation import ( + AttachmentRoleMessageParam, + UserRoleMessageParam, +) from tests.mock_request_base import MockServerBase from tests.mock_server import MockServer @@ -24,14 +27,14 @@ def test_custom_sample(self, mock_server: MockServer): 'index': 0, 'content': '以下是生成Python函数的代码:\n\n```python\ndef file_size(path):\n total_size = 0\n for root, dirs, files in os.walk(path):\n for file in files:\n full_path = os.path.join(root, file)\n total_size += os.path.getsize(full_path)\n return total_size\n```\n\n函数名为`file_size`,输入参数是给定路径`path`。函数通过递归遍历给定路径下的所有文件,使用`os.walk`函数遍历根目录及其子目录下的文件,计算每个文件的大小并累加到总大小上。最后,返回总大小作为函数的返回值。', # noqa E501 - 'frame_id': 25 - }] + 'frame_id': 25, + }], }, 'usage': { 'output_tokens': 198, - 'input_tokens': 46 + 'input_tokens': 46, }, - 'request_id': 'bf321b27-a3ff-9674-a70e-be5f40a435e4' + 'request_id': 'bf321b27-a3ff-9674-a70e-be5f40a435e4', } response_str = json.dumps(response_obj) mock_server.responses.put(response_str) @@ -40,19 +43,22 @@ def test_custom_sample(self, mock_server: MockServer): scene=CodeGeneration.Scenes.custom, message=[ UserRoleMessageParam( - content='根据下面的功能描述生成一个python函数。代码的功能是计算给定路径下所有文件的总大小。') - ]) + content='根据下面的功能描述生成一个python函数。代码的功能是计算给定路径下所有文件的总大小。', + ), + ], + ) req = mock_server.requests.get(block=True) assert req['model'] == model assert req['input']['scene'] == 'custom' assert json.dumps( - req['input']['message'], ensure_ascii=False + req['input']['message'], ensure_ascii=False, ) == '[{"role": "user", "content": "根据下面的功能描述生成一个python函数。代码的功能是计算给定路径下所有文件的总大小。"}]' assert response.status_code == HTTPStatus.OK assert response.request_id == 'bf321b27-a3ff-9674-a70e-be5f40a435e4' assert response.output['choices'][0][ - 'content'] == '以下是生成Python函数的代码:\n\n```python\ndef file_size(path):\n total_size = 0\n for root, 
dirs, files in os.walk(path):\n for file in files:\n full_path = os.path.join(root, file)\n total_size += os.path.getsize(full_path)\n return total_size\n```\n\n函数名为`file_size`,输入参数是给定路径`path`。函数通过递归遍历给定路径下的所有文件,使用`os.walk`函数遍历根目录及其子目录下的文件,计算每个文件的大小并累加到总大小上。最后,返回总大小作为函数的返回值。' # noqa E501 + 'content' + ] == '以下是生成Python函数的代码:\n\n```python\ndef file_size(path):\n total_size = 0\n for root, dirs, files in os.walk(path):\n for file in files:\n full_path = os.path.join(root, file)\n total_size += os.path.getsize(full_path)\n return total_size\n```\n\n函数名为`file_size`,输入参数是给定路径`path`。函数通过递归遍历给定路径下的所有文件,使用`os.walk`函数遍历根目录及其子目录下的文件,计算每个文件的大小并累加到总大小上。最后,返回总大小作为函数的返回值。' # noqa E501 assert response.output['choices'][0]['frame_id'] == 25 assert response.output['choices'][0]['finish_reason'] == 'stop' assert response.usage['output_tokens'] == 198 @@ -67,14 +73,14 @@ def test_custom_dict_sample(self, mock_server: MockServer): 'index': 0, 'content': '以下是生成Python函数的代码:\n\n```python\ndef file_size(path):\n total_size = 0\n for root, dirs, files in os.walk(path):\n for file in files:\n full_path = os.path.join(root, file)\n total_size += os.path.getsize(full_path)\n return total_size\n```\n\n函数名为`file_size`,输入参数是给定路径`path`。函数通过递归遍历给定路径下的所有文件,使用`os.walk`函数遍历根目录及其子目录下的文件,计算每个文件的大小并累加到总大小上。最后,返回总大小作为函数的返回值。', # noqa E501 - 'frame_id': 25 - }] + 'frame_id': 25, + }], }, 'usage': { 'output_tokens': 198, - 'input_tokens': 46 + 'input_tokens': 46, }, - 'request_id': 'bf321b27-a3ff-9674-a70e-be5f40a435e4' + 'request_id': 'bf321b27-a3ff-9674-a70e-be5f40a435e4', } response_str = json.dumps(response_obj) mock_server.responses.put(response_str) @@ -83,19 +89,21 @@ def test_custom_dict_sample(self, mock_server: MockServer): scene=CodeGeneration.Scenes.custom, message=[{ 'role': 'user', - 'content': '根据下面的功能描述生成一个python函数。代码的功能是计算给定路径下所有文件的总大小。' - }]) + 'content': '根据下面的功能描述生成一个python函数。代码的功能是计算给定路径下所有文件的总大小。', + }], + ) req = mock_server.requests.get(block=True) assert req['model'] == model 
assert req['input']['scene'] == 'custom' assert json.dumps( - req['input']['message'], ensure_ascii=False + req['input']['message'], ensure_ascii=False, ) == '[{"role": "user", "content": "根据下面的功能描述生成一个python函数。代码的功能是计算给定路径下所有文件的总大小。"}]' assert response.status_code == HTTPStatus.OK assert response.request_id == 'bf321b27-a3ff-9674-a70e-be5f40a435e4' assert response.output['choices'][0][ - 'content'] == '以下是生成Python函数的代码:\n\n```python\ndef file_size(path):\n total_size = 0\n for root, dirs, files in os.walk(path):\n for file in files:\n full_path = os.path.join(root, file)\n total_size += os.path.getsize(full_path)\n return total_size\n```\n\n函数名为`file_size`,输入参数是给定路径`path`。函数通过递归遍历给定路径下的所有文件,使用`os.walk`函数遍历根目录及其子目录下的文件,计算每个文件的大小并累加到总大小上。最后,返回总大小作为函数的返回值。' # noqa E501 + 'content' + ] == '以下是生成Python函数的代码:\n\n```python\ndef file_size(path):\n total_size = 0\n for root, dirs, files in os.walk(path):\n for file in files:\n full_path = os.path.join(root, file)\n total_size += os.path.getsize(full_path)\n return total_size\n```\n\n函数名为`file_size`,输入参数是给定路径`path`。函数通过递归遍历给定路径下的所有文件,使用`os.walk`函数遍历根目录及其子目录下的文件,计算每个文件的大小并累加到总大小上。最后,返回总大小作为函数的返回值。' # noqa E501 assert response.output['choices'][0]['frame_id'] == 25 assert response.output['choices'][0]['finish_reason'] == 'stop' assert response.usage['output_tokens'] == 198 @@ -110,14 +118,14 @@ def test_nl2code_sample(self, mock_server: MockServer): 'index': 0, 'content': "```java\n/**\n * 计算给定路径下所有文件的总大小\n * @param path 路径\n * @return 总大小,单位为字节\n */\npublic static long getTotalFileSize(String path) {\n long size = 0;\n try {\n File file = new File(path);\n File[] files = file.listFiles();\n for (File f : files) {\n if (f.isFile()) {\n size += f.length();\n }\n }\n } catch (Exception e) {\n e.printStackTrace();\n }\n return size;\n}\n```\n\n使用方式:\n```java\nlong size = getTotalFileSize(\"/home/user/Documents/\");\nSystem.out.println(\"总大小:\" + size + \"字节\");\n```\n\n示例输出:\n```\n总大小:37144952字节\n```", # noqa E501 - 'frame_id': 
29 - }] + 'frame_id': 29, + }], }, 'usage': { 'output_tokens': 229, - 'input_tokens': 39 + 'input_tokens': 39, }, - 'request_id': '59bbbea3-29a7-94d6-8c39-e4d6e465f640' + 'request_id': '59bbbea3-29a7-94d6-8c39-e4d6e465f640', } response_str = json.dumps(response_obj) mock_server.responses.put(response_str) @@ -126,19 +134,21 @@ def test_nl2code_sample(self, mock_server: MockServer): scene=CodeGeneration.Scenes.nl2code, message=[ UserRoleMessageParam(content='计算给定路径下所有文件的总大小'), - AttachmentRoleMessageParam(meta={'language': 'java'}) - ]) + AttachmentRoleMessageParam(meta={'language': 'java'}), + ], + ) req = mock_server.requests.get(block=True) assert req['model'] == model assert req['input']['scene'] == 'nl2code' assert json.dumps( - req['input']['message'], ensure_ascii=False + req['input']['message'], ensure_ascii=False, ) == '[{"role": "user", "content": "计算给定路径下所有文件的总大小"}, {"role": "attachment", "meta": {"language": "java"}}]' assert response.status_code == HTTPStatus.OK assert response.request_id == '59bbbea3-29a7-94d6-8c39-e4d6e465f640' assert response.output['choices'][0][ - 'content'] == "```java\n/**\n * 计算给定路径下所有文件的总大小\n * @param path 路径\n * @return 总大小,单位为字节\n */\npublic static long getTotalFileSize(String path) {\n long size = 0;\n try {\n File file = new File(path);\n File[] files = file.listFiles();\n for (File f : files) {\n if (f.isFile()) {\n size += f.length();\n }\n }\n } catch (Exception e) {\n e.printStackTrace();\n }\n return size;\n}\n```\n\n使用方式:\n```java\nlong size = getTotalFileSize(\"/home/user/Documents/\");\nSystem.out.println(\"总大小:\" + size + \"字节\");\n```\n\n示例输出:\n```\n总大小:37144952字节\n```" # noqa E501 + 'content' + ] == "```java\n/**\n * 计算给定路径下所有文件的总大小\n * @param path 路径\n * @return 总大小,单位为字节\n */\npublic static long getTotalFileSize(String path) {\n long size = 0;\n try {\n File file = new File(path);\n File[] files = file.listFiles();\n for (File f : files) {\n if (f.isFile()) {\n size += f.length();\n }\n }\n } catch (Exception 
e) {\n e.printStackTrace();\n }\n return size;\n}\n```\n\n使用方式:\n```java\nlong size = getTotalFileSize(\"/home/user/Documents/\");\nSystem.out.println(\"总大小:\" + size + \"字节\");\n```\n\n示例输出:\n```\n总大小:37144952字节\n```" # noqa E501 assert response.output['choices'][0]['frame_id'] == 29 assert response.output['choices'][0]['finish_reason'] == 'stop' assert response.usage['output_tokens'] == 229 @@ -153,14 +163,14 @@ def test_code2comment_sample(self, mock_server: MockServer): 'index': 0, 'content': '```java\n/**\n * 取消导出任务的回调函数\n *\n * @param cancelExportTask 取消导出任务的请求对象\n * @return 取消导出任务的响应对象\n */\n@Override\npublic CancelExportTaskResponse cancelExportTask(CancelExportTask cancelExportTask) {\n\tAmazonEC2SkeletonInterface ec2Service = ServiceProvider.getInstance().getServiceImpl(AmazonEC2SkeletonInterface.class);\n\treturn ec2Service.cancelExportTask(cancelExportTask);\n}\n```', # noqa E501 - 'frame_id': 17 - }] + 'frame_id': 17, + }], }, 'usage': { 'output_tokens': 133, - 'input_tokens': 141 + 'input_tokens': 141, }, - 'request_id': 'b5e55877-bfa3-9863-88d8-09a72124cf8a' + 'request_id': 'b5e55877-bfa3-9863-88d8-09a72124cf8a', } response_str = json.dumps(response_obj) mock_server.responses.put(response_str) @@ -169,26 +179,30 @@ def test_code2comment_sample(self, mock_server: MockServer): scene=CodeGeneration.Scenes.code2comment, message=[ UserRoleMessageParam( - content='1. 生成中文注释\n2. 仅生成代码部分,不需要额外解释函数功能\n'), + content='1. 生成中文注释\n2. 
仅生成代码部分,不需要额外解释函数功能\n', + ), AttachmentRoleMessageParam( meta={ 'code': '\t\t@Override\n\t\tpublic CancelExportTaskResponse cancelExportTask(\n\t\t\t\tCancelExportTask cancelExportTask) {\n\t\t\tAmazonEC2SkeletonInterface ec2Service = ServiceProvider.getInstance().getServiceImpl(AmazonEC2SkeletonInterface.class);\n\t\t\treturn ec2Service.cancelExportTask(cancelExportTask);\n\t\t}', # noqa E501 - 'language': 'java' - }) - ]) + 'language': 'java', + }, + ), + ], + ) req = mock_server.requests.get(block=True) assert req['model'] == model assert req['input']['scene'] == 'code2comment' assert json.dumps( - req['input']['message'], ensure_ascii=False + req['input']['message'], ensure_ascii=False, ) == '[{"role": "user", "content": "1. 生成中文注释\n2. 仅生成代码部分,不需要额外解释函数功能\n"}, {"role": "attachment", "meta": {"code": "\t\t@Override\n\t\tpublic CancelExportTaskResponse cancelExportTask(\n\t\t\t\tCancelExportTask cancelExportTask) {\n\t\t\tAmazonEC2SkeletonInterface ec2Service = ServiceProvider.getInstance().getServiceImpl(AmazonEC2SkeletonInterface.class);\n\t\t\treturn ec2Service.cancelExportTask(cancelExportTask);\n\t\t}", "language": "java"}}]'.replace('\t', '\\t').replace('\n', '\\n') # noqa E501 assert response.status_code == HTTPStatus.OK assert response.request_id == 'b5e55877-bfa3-9863-88d8-09a72124cf8a' assert response.output['choices'][0][ - 'content'] == '```java\n/**\n * 取消导出任务的回调函数\n *\n * @param cancelExportTask 取消导出任务的请求对象\n * @return 取消导出任务的响应对象\n */\n@Override\npublic CancelExportTaskResponse cancelExportTask(CancelExportTask cancelExportTask) {\n\tAmazonEC2SkeletonInterface ec2Service = ServiceProvider.getInstance().getServiceImpl(AmazonEC2SkeletonInterface.class);\n\treturn ec2Service.cancelExportTask(cancelExportTask);\n}\n```' # noqa E501 + 'content' + ] == '```java\n/**\n * 取消导出任务的回调函数\n *\n * @param cancelExportTask 取消导出任务的请求对象\n * @return 取消导出任务的响应对象\n */\n@Override\npublic CancelExportTaskResponse cancelExportTask(CancelExportTask cancelExportTask) 
{\n\tAmazonEC2SkeletonInterface ec2Service = ServiceProvider.getInstance().getServiceImpl(AmazonEC2SkeletonInterface.class);\n\treturn ec2Service.cancelExportTask(cancelExportTask);\n}\n```' # noqa E501 assert response.output['choices'][0]['frame_id'] == 17 assert response.output['choices'][0]['finish_reason'] == 'stop' assert response.usage['output_tokens'] == 133 @@ -203,14 +217,14 @@ def test_code2explain_sample(self, mock_server: MockServer): 'index': 0, 'content': '这个Java函数是一个覆盖了另一个方法的函数,名为`getHeaderCacheSize()`。这个方法是从另一个已覆盖的方法继承过来的。在`@Override`声明中,可以确定这个函数覆盖了一个其他的函数。这个函数的返回类型是`int`。\n\n函数内容是:返回0。这个值意味着在`getHeaderCacheSize()`方法中,不会进行任何处理或更新。因此,返回的`0`值应该是没有被处理或更新的值。\n\n总的来说,这个函数的作用可能是为了让另一个方法返回一个预设的值。但是由于`@Override`的提示,我们无法确定它的真正目的,需要进一步查看代码才能得到更多的信息。', # noqa E501 - 'frame_id': 30 - }] + 'frame_id': 30, + }], }, 'usage': { 'output_tokens': 235, - 'input_tokens': 55 + 'input_tokens': 55, }, - 'request_id': '089e525f-d28f-9e08-baa2-01dde87c90a7' + 'request_id': '089e525f-d28f-9e08-baa2-01dde87c90a7', } response_str = json.dumps(response_obj) mock_server.responses.put(response_str) @@ -223,20 +237,23 @@ def test_code2explain_sample(self, mock_server: MockServer): meta={ 'code': '@Override\n public int getHeaderCacheSize()\n {\n return 0;\n }\n\n', # noqa E501 - 'language': 'java' - }) - ]) + 'language': 'java', + }, + ), + ], + ) req = mock_server.requests.get(block=True) assert req['model'] == model assert req['input']['scene'] == 'code2explain' assert json.dumps( - req['input']['message'], ensure_ascii=False + req['input']['message'], ensure_ascii=False, ) == '[{"role": "user", "content": "要求不低于200字"}, {"role": "attachment", "meta": {"code": "@Override\n public int getHeaderCacheSize()\n {\n return 0;\n }\n\n", "language": "java"}}]'.replace('\t', '\\t').replace('\n', '\\n') # noqa E501 assert response.status_code == HTTPStatus.OK assert response.request_id == '089e525f-d28f-9e08-baa2-01dde87c90a7' assert response.output['choices'][0][ - 'content'] == 
'这个Java函数是一个覆盖了另一个方法的函数,名为`getHeaderCacheSize()`。这个方法是从另一个已覆盖的方法继承过来的。在`@Override`声明中,可以确定这个函数覆盖了一个其他的函数。这个函数的返回类型是`int`。\n\n函数内容是:返回0。这个值意味着在`getHeaderCacheSize()`方法中,不会进行任何处理或更新。因此,返回的`0`值应该是没有被处理或更新的值。\n\n总的来说,这个函数的作用可能是为了让另一个方法返回一个预设的值。但是由于`@Override`的提示,我们无法确定它的真正目的,需要进一步查看代码才能得到更多的信息。' # noqa E501 + 'content' + ] == '这个Java函数是一个覆盖了另一个方法的函数,名为`getHeaderCacheSize()`。这个方法是从另一个已覆盖的方法继承过来的。在`@Override`声明中,可以确定这个函数覆盖了一个其他的函数。这个函数的返回类型是`int`。\n\n函数内容是:返回0。这个值意味着在`getHeaderCacheSize()`方法中,不会进行任何处理或更新。因此,返回的`0`值应该是没有被处理或更新的值。\n\n总的来说,这个函数的作用可能是为了让另一个方法返回一个预设的值。但是由于`@Override`的提示,我们无法确定它的真正目的,需要进一步查看代码才能得到更多的信息。' # noqa E501 assert response.output['choices'][0]['frame_id'] == 30 assert response.output['choices'][0]['finish_reason'] == 'stop' assert response.usage['output_tokens'] == 235 @@ -250,14 +267,14 @@ def test_commit2msg_sample(self, mock_server: MockServer): 'frame_timestamp': 1694697276.4451804, 'index': 0, 'content': 'Remove old listFolder method', - 'frame_id': 1 - }] + 'frame_id': 1, + }], }, 'usage': { 'output_tokens': 5, - 'input_tokens': 197 + 'input_tokens': 197, }, - 'request_id': '8f400a4e-6448-94ab-89bf-a97b1a7e6fe6' + 'request_id': '8f400a4e-6448-94ab-89bf-a97b1a7e6fe6', } response_str = json.dumps(response_obj) mock_server.responses.put(response_str) @@ -273,21 +290,24 @@ def test_commit2msg_sample(self, mock_server: MockServer): 'old_file_path': 'src/com/siondream/core/PlatformResolver.java', 'new_file_path': - 'src/com/siondream/core/PlatformResolver.java' - }] - }) - ]) + 'src/com/siondream/core/PlatformResolver.java', + }], + }, + ), + ], + ) req = mock_server.requests.get(block=True) assert req['model'] == model assert req['input']['scene'] == 'commit2msg' assert json.dumps( - req['input']['message'], ensure_ascii=False + req['input']['message'], ensure_ascii=False, ) == '[{"role": "attachment", "meta": {"diff_list": [{"diff": "--- src/com/siondream/core/PlatformResolver.java\n+++ src/com/siondream/core/PlatformResolver.java\n@@ -1,11 +1,8 
@@\npackage com.siondream.core;\n-\n-import com.badlogic.gdx.files.FileHandle;\n\npublic interface PlatformResolver {\npublic void openURL(String url);\npublic void rateApp();\npublic void sendFeedback();\n-\tpublic FileHandle[] listFolder(String path);\n}\n", "old_file_path": "src/com/siondream/core/PlatformResolver.java", "new_file_path": "src/com/siondream/core/PlatformResolver.java"}]}}]'.replace('\t', '\\t').replace('\n', '\\n') # noqa E501 assert response.status_code == HTTPStatus.OK assert response.request_id == '8f400a4e-6448-94ab-89bf-a97b1a7e6fe6' assert response.output['choices'][0][ - 'content'] == 'Remove old listFolder method' + 'content' + ] == 'Remove old listFolder method' assert response.output['choices'][0]['frame_id'] == 1 assert response.output['choices'][0]['finish_reason'] == 'stop' assert response.usage['output_tokens'] == 5 @@ -302,14 +322,14 @@ def test_unittest_sample(self, mock_server: MockServer): 'index': 0, 'content': "这个函数用于解析时间戳映射表的输入字符串并返回该映射表的实例。函数有两个必选参数:typeClass - 用于标识数据类型的泛型;input - 输入的时间戳映射表字符串。如果typeClass为null,将抛出IllegalArgumentException异常;如果input为null,则返回null。函数内部首先检查输入的字符串是否等于\"空字符串\",如果是,则直接返回null;如果不是,则创建TimestampMap的实例,并使用input字符串创建字符串Reader对象。然后使用读取器逐个字符解析时间戳字符串,并在解析完成后返回相应的TimestampMap对象。函数的行为取决于传入的时间戳字符串类型。", # noqa E501 - 'frame_id': 29 - }] + 'frame_id': 29, + }], }, 'usage': { 'output_tokens': 227, - 'input_tokens': 659 + 'input_tokens': 659, }, - 'request_id': '6ec31e35-f355-9289-a18d-103abc36dece' + 'request_id': '6ec31e35-f355-9289-a18d-103abc36dece', } response_str = json.dumps(response_obj) mock_server.responses.put(response_str) @@ -321,21 +341,25 @@ def test_unittest_sample(self, mock_server: MockServer): meta={ 'code': "public static TimestampMap parseTimestampMap(Class typeClass, String input, DateTimeZone timeZone) throws IllegalArgumentException {\n if (typeClass == null) {\n throw new IllegalArgumentException(\"typeClass required\");\n }\n\n if (input == null) {\n return null;\n }\n\n TimestampMap 
result;\n\n typeClass = AttributeUtils.getStandardizedType(typeClass);\n if (typeClass.equals(String.class)) {\n result = new TimestampStringMap();\n } else if (typeClass.equals(Byte.class)) {\n result = new TimestampByteMap();\n } else if (typeClass.equals(Short.class)) {\n result = new TimestampShortMap();\n } else if (typeClass.equals(Integer.class)) {\n result = new TimestampIntegerMap();\n } else if (typeClass.equals(Long.class)) {\n result = new TimestampLongMap();\n } else if (typeClass.equals(Float.class)) {\n result = new TimestampFloatMap();\n } else if (typeClass.equals(Double.class)) {\n result = new TimestampDoubleMap();\n } else if (typeClass.equals(Boolean.class)) {\n result = new TimestampBooleanMap();\n } else if (typeClass.equals(Character.class)) {\n result = new TimestampCharMap();\n } else {\n throw new IllegalArgumentException(\"Unsupported type \" + typeClass.getClass().getCanonicalName());\n }\n\n if (input.equalsIgnoreCase(EMPTY_VALUE)) {\n return result;\n }\n\n StringReader reader = new StringReader(input + ' ');// Add 1 space so\n // reader.skip\n // function always\n // works when\n // necessary (end of\n // string not\n // reached).\n\n try {\n int r;\n char c;\n while ((r = reader.read()) != -1) {\n c = (char) r;\n switch (c) {\n case LEFT_BOUND_SQUARE_BRACKET:\n case LEFT_BOUND_BRACKET:\n parseTimestampAndValue(typeClass, reader, result, timeZone);\n break;\n default:\n // Ignore other chars outside of bounds\n }\n }\n } catch (IOException ex) {\n throw new RuntimeException(\"Unexpected expection while parsing timestamps\", ex);\n }\n\n return result;\n }", # noqa E501 - 'language': 'java' - }) - ]) + 'language': 'java', + }, + ), + ], + ) req = mock_server.requests.get(block=True) assert req['model'] == model assert req['input']['scene'] == 'unittest' assert req['input']['message'][0]['role'] == 'attachment' assert req['input']['message'][0]['meta'][ - 'code'] == """public static TimestampMap parseTimestampMap(Class typeClass, 
String input, DateTimeZone timeZone) throws IllegalArgumentException {\n if (typeClass == null) {\n throw new IllegalArgumentException(\"typeClass required\");\n }\n\n if (input == null) {\n return null;\n }\n\n TimestampMap result;\n\n typeClass = AttributeUtils.getStandardizedType(typeClass);\n if (typeClass.equals(String.class)) {\n result = new TimestampStringMap();\n } else if (typeClass.equals(Byte.class)) {\n result = new TimestampByteMap();\n } else if (typeClass.equals(Short.class)) {\n result = new TimestampShortMap();\n } else if (typeClass.equals(Integer.class)) {\n result = new TimestampIntegerMap();\n } else if (typeClass.equals(Long.class)) {\n result = new TimestampLongMap();\n } else if (typeClass.equals(Float.class)) {\n result = new TimestampFloatMap();\n } else if (typeClass.equals(Double.class)) {\n result = new TimestampDoubleMap();\n } else if (typeClass.equals(Boolean.class)) {\n result = new TimestampBooleanMap();\n } else if (typeClass.equals(Character.class)) {\n result = new TimestampCharMap();\n } else {\n throw new IllegalArgumentException(\"Unsupported type \" + typeClass.getClass().getCanonicalName());\n }\n\n if (input.equalsIgnoreCase(EMPTY_VALUE)) {\n return result;\n }\n\n StringReader reader = new StringReader(input + ' ');// Add 1 space so\n // reader.skip\n // function always\n // works when\n // necessary (end of\n // string not\n // reached).\n\n try {\n int r;\n char c;\n while ((r = reader.read()) != -1) {\n c = (char) r;\n switch (c) {\n case LEFT_BOUND_SQUARE_BRACKET:\n case LEFT_BOUND_BRACKET:\n parseTimestampAndValue(typeClass, reader, result, timeZone);\n break;\n default:\n // Ignore other chars outside of bounds\n }\n }\n } catch (IOException ex) {\n throw new RuntimeException(\"Unexpected expection while parsing timestamps\", ex);\n }\n\n return result;\n }""" # noqa E501 + 'code' + ] == """public static TimestampMap parseTimestampMap(Class typeClass, String input, DateTimeZone timeZone) throws 
IllegalArgumentException {\n if (typeClass == null) {\n throw new IllegalArgumentException(\"typeClass required\");\n }\n\n if (input == null) {\n return null;\n }\n\n TimestampMap result;\n\n typeClass = AttributeUtils.getStandardizedType(typeClass);\n if (typeClass.equals(String.class)) {\n result = new TimestampStringMap();\n } else if (typeClass.equals(Byte.class)) {\n result = new TimestampByteMap();\n } else if (typeClass.equals(Short.class)) {\n result = new TimestampShortMap();\n } else if (typeClass.equals(Integer.class)) {\n result = new TimestampIntegerMap();\n } else if (typeClass.equals(Long.class)) {\n result = new TimestampLongMap();\n } else if (typeClass.equals(Float.class)) {\n result = new TimestampFloatMap();\n } else if (typeClass.equals(Double.class)) {\n result = new TimestampDoubleMap();\n } else if (typeClass.equals(Boolean.class)) {\n result = new TimestampBooleanMap();\n } else if (typeClass.equals(Character.class)) {\n result = new TimestampCharMap();\n } else {\n throw new IllegalArgumentException(\"Unsupported type \" + typeClass.getClass().getCanonicalName());\n }\n\n if (input.equalsIgnoreCase(EMPTY_VALUE)) {\n return result;\n }\n\n StringReader reader = new StringReader(input + ' ');// Add 1 space so\n // reader.skip\n // function always\n // works when\n // necessary (end of\n // string not\n // reached).\n\n try {\n int r;\n char c;\n while ((r = reader.read()) != -1) {\n c = (char) r;\n switch (c) {\n case LEFT_BOUND_SQUARE_BRACKET:\n case LEFT_BOUND_BRACKET:\n parseTimestampAndValue(typeClass, reader, result, timeZone);\n break;\n default:\n // Ignore other chars outside of bounds\n }\n }\n } catch (IOException ex) {\n throw new RuntimeException(\"Unexpected expection while parsing timestamps\", ex);\n }\n\n return result;\n }""" # noqa E501 assert req['input']['message'][0]['meta']['language'] == 'java' assert response.status_code == HTTPStatus.OK assert response.request_id == '6ec31e35-f355-9289-a18d-103abc36dece' assert 
response.output['choices'][0][ - 'content'] == "这个函数用于解析时间戳映射表的输入字符串并返回该映射表的实例。函数有两个必选参数:typeClass - 用于标识数据类型的泛型;input - 输入的时间戳映射表字符串。如果typeClass为null,将抛出IllegalArgumentException异常;如果input为null,则返回null。函数内部首先检查输入的字符串是否等于\"空字符串\",如果是,则直接返回null;如果不是,则创建TimestampMap的实例,并使用input字符串创建字符串Reader对象。然后使用读取器逐个字符解析时间戳字符串,并在解析完成后返回相应的TimestampMap对象。函数的行为取决于传入的时间戳字符串类型。" # noqa E501 + 'content' + ] == "这个函数用于解析时间戳映射表的输入字符串并返回该映射表的实例。函数有两个必选参数:typeClass - 用于标识数据类型的泛型;input - 输入的时间戳映射表字符串。如果typeClass为null,将抛出IllegalArgumentException异常;如果input为null,则返回null。函数内部首先检查输入的字符串是否等于\"空字符串\",如果是,则直接返回null;如果不是,则创建TimestampMap的实例,并使用input字符串创建字符串Reader对象。然后使用读取器逐个字符解析时间戳字符串,并在解析完成后返回相应的TimestampMap对象。函数的行为取决于传入的时间戳字符串类型。" # noqa E501 assert response.output['choices'][0]['frame_id'] == 29 assert response.output['choices'][0]['finish_reason'] == 'stop' assert response.usage['output_tokens'] == 227 @@ -350,14 +374,14 @@ def test_codeqa_sample(self, mock_server: MockServer): 'index': 0, 'content': "Yes, this is possible:\nclass MyRequestHandler(BaseHTTPServer.BaseHTTPRequestHandler):\n [...]\n\n def doGET(self):\n # some stuff\n if \"X-Port\" in self.headers:\n # change the port in this request\n self.server_port = int(self.headers[\"X-Port\"])\n print(\"Changed port: %s\" % self.server_port)\n [...]\n\nclass ThreadingHTTPServer(ThreadingMixIn, HTTPServer): \n pass\n\nserver = ThreadingHTTPServer(('localhost', self.server_port), MyRequestHandler)\nserver.serve_forever()", # noqa E501 - 'frame_id': 19 - }] + 'frame_id': 19, + }], }, 'usage': { 'output_tokens': 150, - 'input_tokens': 127 + 'input_tokens': 127, }, - 'request_id': 'e09386b7-5171-96b0-9c6f-7128507e14e6' + 'request_id': 'e09386b7-5171-96b0-9c6f-7128507e14e6', } response_str = json.dumps(response_obj) mock_server.responses.put(response_str) @@ -366,20 +390,23 @@ def test_codeqa_sample(self, mock_server: MockServer): scene=CodeGeneration.Scenes.code_qa, message=[ UserRoleMessageParam( - content="I'm writing a small web server in 
Python, using BaseHTTPServer and a custom subclass of BaseHTTPServer.BaseHTTPRequestHandler. Is it possible to make this listen on more than one port?\nWhat I'm doing now:\nclass MyRequestHandler(BaseHTTPServer.BaseHTTPRequestHandler):\n def doGET\n [...]\n\nclass ThreadingHTTPServer(ThreadingMixIn, HTTPServer): \n pass\n\nserver = ThreadingHTTPServer(('localhost', 80), MyRequestHandler)\nserver.serve_forever()" # noqa E501 - ) - ]) + content="I'm writing a small web server in Python, using BaseHTTPServer and a custom subclass of BaseHTTPServer.BaseHTTPRequestHandler. Is it possible to make this listen on more than one port?\nWhat I'm doing now:\nclass MyRequestHandler(BaseHTTPServer.BaseHTTPRequestHandler):\n def doGET\n [...]\n\nclass ThreadingHTTPServer(ThreadingMixIn, HTTPServer): \n pass\n\nserver = ThreadingHTTPServer(('localhost', 80), MyRequestHandler)\nserver.serve_forever()", # noqa E501 + ), + ], + ) req = mock_server.requests.get(block=True) assert req['model'] == model assert req['input']['scene'] == 'codeqa' assert req['input']['message'][0]['role'] == 'user' assert req['input']['message'][0][ - 'content'] == """I'm writing a small web server in Python, using BaseHTTPServer and a custom subclass of BaseHTTPServer.BaseHTTPRequestHandler. Is it possible to make this listen on more than one port?\nWhat I'm doing now:\nclass MyRequestHandler(BaseHTTPServer.BaseHTTPRequestHandler):\n def doGET\n [...]\n\nclass ThreadingHTTPServer(ThreadingMixIn, HTTPServer): \n pass\n\nserver = ThreadingHTTPServer(('localhost', 80), MyRequestHandler)\nserver.serve_forever()""" # noqa E501 + 'content' + ] == """I'm writing a small web server in Python, using BaseHTTPServer and a custom subclass of BaseHTTPServer.BaseHTTPRequestHandler. 
Is it possible to make this listen on more than one port?\nWhat I'm doing now:\nclass MyRequestHandler(BaseHTTPServer.BaseHTTPRequestHandler):\n def doGET\n [...]\n\nclass ThreadingHTTPServer(ThreadingMixIn, HTTPServer): \n pass\n\nserver = ThreadingHTTPServer(('localhost', 80), MyRequestHandler)\nserver.serve_forever()""" # noqa E501 assert response.status_code == HTTPStatus.OK assert response.request_id == 'e09386b7-5171-96b0-9c6f-7128507e14e6' assert response.output['choices'][0][ - 'content'] == "Yes, this is possible:\nclass MyRequestHandler(BaseHTTPServer.BaseHTTPRequestHandler):\n [...]\n\n def doGET(self):\n # some stuff\n if \"X-Port\" in self.headers:\n # change the port in this request\n self.server_port = int(self.headers[\"X-Port\"])\n print(\"Changed port: %s\" % self.server_port)\n [...]\n\nclass ThreadingHTTPServer(ThreadingMixIn, HTTPServer): \n pass\n\nserver = ThreadingHTTPServer(('localhost', self.server_port), MyRequestHandler)\nserver.serve_forever()" # noqa E501 + 'content' + ] == "Yes, this is possible:\nclass MyRequestHandler(BaseHTTPServer.BaseHTTPRequestHandler):\n [...]\n\n def doGET(self):\n # some stuff\n if \"X-Port\" in self.headers:\n # change the port in this request\n self.server_port = int(self.headers[\"X-Port\"])\n print(\"Changed port: %s\" % self.server_port)\n [...]\n\nclass ThreadingHTTPServer(ThreadingMixIn, HTTPServer): \n pass\n\nserver = ThreadingHTTPServer(('localhost', self.server_port), MyRequestHandler)\nserver.serve_forever()" # noqa E501 assert response.output['choices'][0]['frame_id'] == 19 assert response.output['choices'][0]['finish_reason'] == 'stop' assert response.usage['output_tokens'] == 150 @@ -394,14 +421,14 @@ def test_nl2sql_sample(self, mock_server: MockServer): 'index': 0, 'content': "SELECT SUM(score) as '小明的总分数' FROM student_score WHERE name = '小明';", - 'frame_id': 3 - }] + 'frame_id': 3, + }], }, 'usage': { 'output_tokens': 25, - 'input_tokens': 420 + 'input_tokens': 420, }, - 'request_id': 
'e61a35b7-db6f-90c2-8677-9620ffea63b6' + 'request_id': 'e61a35b7-db6f-90c2-8677-9620ffea63b6', } response_str = json.dumps(response_obj) mock_server.responses.put(response_str) @@ -414,36 +441,40 @@ def test_nl2sql_sample(self, mock_server: MockServer): meta={ 'synonym_infos': { '学生姓名': '姓名|名字|名称', - '学生分数': '分数|得分' + '学生分数': '分数|得分', }, 'recall_infos': [{ 'content': "student_score.id='小明'", - 'score': '0.83' + 'score': '0.83', }], 'schema_infos': [{ 'table_id': 'student_score', 'table_desc': '学生分数表', - 'columns': [{ - 'col_name': 'id', - 'col_caption': '学生id', - 'col_desc': '例值为:1,2,3', - 'col_type': 'string' - }, { - 'col_name': 'name', - 'col_caption': '学生姓名', - 'col_desc': '例值为:张三,李四,小明', - 'col_type': 'string' - }, { - 'col_name': 'score', - 'col_caption': '学生分数', - 'col_desc': '例值为:98,100,66', - 'col_type': 'string' - }] - }] - }) - ]) + 'columns': [ + { + 'col_name': 'id', + 'col_caption': '学生id', + 'col_desc': '例值为:1,2,3', + 'col_type': 'string', + }, { + 'col_name': 'name', + 'col_caption': '学生姓名', + 'col_desc': '例值为:张三,李四,小明', + 'col_type': 'string', + }, { + 'col_name': 'score', + 'col_caption': '学生分数', + 'col_desc': '例值为:98,100,66', + 'col_type': 'string', + }, + ], + }], + }, + ), + ], + ) req = mock_server.requests.get(block=True) assert req['model'] == model assert req['input']['scene'] == 'nl2sql' @@ -452,7 +483,8 @@ def test_nl2sql_sample(self, mock_server: MockServer): assert response.status_code == HTTPStatus.OK assert response.request_id == 'e61a35b7-db6f-90c2-8677-9620ffea63b6' assert response.output['choices'][0][ - 'content'] == "SELECT SUM(score) as '小明的总分数' FROM student_score WHERE name = '小明';" + 'content' + ] == "SELECT SUM(score) as '小明的总分数' FROM student_score WHERE name = '小明';" assert response.output['choices'][0]['frame_id'] == 3 assert response.output['choices'][0]['finish_reason'] == 'stop' assert response.usage['output_tokens'] == 25 diff --git a/tests/test_conversation.py b/tests/test_conversation.py index 79d6b03..11ffe80 100644 
--- a/tests/test_conversation.py +++ b/tests/test_conversation.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. import json @@ -12,7 +13,7 @@ def request_generator(): - yield 'hello' + yield "hello" model = Generation.Models.qwen_turbo @@ -20,187 +21,209 @@ def request_generator(): class TestConversationRequest(MockServerBase): text_response_obj = { - 'status_code': 200, - 'request_id': 'effd2cb1-1a8c-9f18-9a49-a396f673bd40', - 'code': '', - 'message': '', - 'output': { - 'text': 'hello', - 'choices': None, - 'finish_reason': 'stop' + "status_code": 200, + "request_id": "effd2cb1-1a8c-9f18-9a49-a396f673bd40", + "code": "", + "message": "", + "output": { + "text": "hello", + "choices": None, + "finish_reason": "stop", + }, + "usage": { + "input_tokens": 27, + "output_tokens": 110, }, - 'usage': { - 'input_tokens': 27, - 'output_tokens': 110 - } } message_response_obj = { - 'status_code': 200, - 'request_id': '1', - 'code': '', - 'message': '', - 'output': { - 'text': - None, - 'choices': [{ - 'finish_reason': 'stop', - 'message': { - 'role': 'assistant', - 'content': 'hello world' - } - }] + "status_code": 200, + "request_id": "1", + "code": "", + "message": "", + "output": { + "text": None, + "choices": [ + { + "finish_reason": "stop", + "message": { + "role": "assistant", + "content": "hello world", + }, + }, + ], + }, + "usage": { + "input_tokens": 27, + "output_tokens": 110, }, - 'usage': { - 'input_tokens': 27, - 'output_tokens': 110 - } } def test_message_request(self, mock_server: MockServer): response_str = json.dumps(TestConversationRequest.message_response_obj) mock_server.responses.put(response_str) - prompt = 'hello' - messages = [{'role': 'user', 'content': prompt}] - resp = Generation.call(model=model, - messages=messages, - max_tokens=1024, - api_protocol='http', - result_format='message', - n=50) + prompt = "hello" + messages = [{"role": "user", "content": prompt}] + resp = Generation.call( + model=model, 
+ messages=messages, + max_tokens=1024, + api_protocol="http", + result_format="message", + n=50, + ) req = mock_server.requests.get() - assert req['model'] == model - assert req['parameters']['result_format'] == 'message' - assert req['input']['messages'][0] == Message(role=Role.USER, - content=prompt) + assert req["model"] == model + assert req["parameters"]["result_format"] == "message" + assert req["input"]["messages"][0] == Message( + role=Role.USER, + content=prompt, + ) assert resp.output.text is None - assert resp.output.choices[0] == Choice(finish_reason='stop', - message={ - 'role': 'assistant', - 'content': 'hello world' - }) + assert resp.output.choices[0] == Choice( + finish_reason="stop", + message={ + "role": "assistant", + "content": "hello world", + }, + ) def test_conversation_with_history(self, mock_server: MockServer): history = History() - item = HistoryItem('user', text='今天天气好吗') + item = HistoryItem("user", text="今天天气好吗") history.append(item) - item = HistoryItem('bot', text='今天天气不错,要出去玩玩嘛?') + item = HistoryItem("bot", text="今天天气不错,要出去玩玩嘛?") history.append(item) - item = HistoryItem('user', text='那你有什么地方推荐?') + item = HistoryItem("user", text="那你有什么地方推荐?") history.append(item) - item = HistoryItem('bot', text='我建议你去公园,春天来了,花朵开了,很美丽。') + item = HistoryItem("bot", text="我建议你去公园,春天来了,花朵开了,很美丽。") history.append(item) mock_server.responses.put( - json.dumps(TestConversationRequest.text_response_obj)) + json.dumps(TestConversationRequest.text_response_obj), + ) chat = Conversation(history) response = chat.call( model, - prompt='推荐一个附近的公园', + prompt="推荐一个附近的公园", auto_history=True, ) assert response.status_code == HTTPStatus.OK - assert response.output.text == 'hello' + assert response.output.text == "hello" assert response.output.choices is None - assert response.output.finish_reason == 'stop' + assert response.output.finish_reason == "stop" req = mock_server.requests.get(block=True) - assert req['model'] == model - assert req['parameters'] == {} - 
assert len(req['input']['history']) == 2 + assert req["model"] == model + assert req["parameters"] == {} + assert len(req["input"]["history"]) == 2 mock_server.responses.put( - json.dumps(TestConversationRequest.text_response_obj)) + json.dumps(TestConversationRequest.text_response_obj), + ) response = chat.call( model, - prompt='这个公园去过很多次了,远一点的呢', + prompt="这个公园去过很多次了,远一点的呢", auto_history=True, ) assert response.status_code == HTTPStatus.OK - assert response.output.text == 'hello' + assert response.output.text == "hello" assert response.output.choices is None - assert response.output.finish_reason == 'stop' + assert response.output.finish_reason == "stop" req = mock_server.requests.get(block=True) print(req) - assert req['model'] == model - assert req['parameters'] == {} - assert len(req['input']['history']) == 3 - - def test_conversation_with_message_and_prompt(self, - mock_server: MockServer): + assert req["model"] == model + assert req["parameters"] == {} + assert len(req["input"]["history"]) == 3 + + def test_conversation_with_message_and_prompt( + self, + mock_server: MockServer, + ): messageManager = MessageManager(10) - messageManager.add(Message(role='system', content='你是达摩院的生活助手机器人。')) + messageManager.add(Message(role="system", content="你是达摩院的生活助手机器人。")) mock_server.responses.put( - json.dumps(TestConversationRequest.message_response_obj)) + json.dumps(TestConversationRequest.message_response_obj), + ) conv = Conversation() response = conv.call( model, - prompt='推荐一个附近的公园', + prompt="推荐一个附近的公园", messages=messageManager.get(), - result_format='message', + result_format="message", ) assert response.status_code == HTTPStatus.OK assert response.output.text is None - choices = TestConversationRequest.message_response_obj['output'][ - 'choices'] + choices = TestConversationRequest.message_response_obj["output"][ + "choices" + ] assert response.output.choices == choices req = mock_server.requests.get(block=True) - assert req['model'] == model - assert 
req['parameters'] == {'result_format': 'message'} - assert len(req['input']['messages']) == 2 + assert req["model"] == model + assert req["parameters"] == {"result_format": "message"} + assert len(req["input"]["messages"]) == 2 def test_conversation_with_messages(self, mock_server: MockServer): messageManager = MessageManager(10) - messageManager.add(Message(role='system', content='你是达摩院的生活助手机器人。')) - messageManager.add(Message(role=Role.USER, content='推荐一个附近的公园')) + messageManager.add(Message(role="system", content="你是达摩院的生活助手机器人。")) + messageManager.add(Message(role=Role.USER, content="推荐一个附近的公园")) mock_server.responses.put( - json.dumps(TestConversationRequest.message_response_obj)) + json.dumps(TestConversationRequest.message_response_obj), + ) conv = Conversation() - response = conv.call(model, - messages=messageManager.get(), - result_format='message') + response = conv.call( + model, + messages=messageManager.get(), + result_format="message", + ) assert response.status_code == HTTPStatus.OK assert response.output.text is None - choices = TestConversationRequest.message_response_obj['output'][ - 'choices'] + choices = TestConversationRequest.message_response_obj["output"][ + "choices" + ] assert response.output.choices == choices req = mock_server.requests.get(block=True) - assert req['model'] == model - assert req['parameters'] == {'result_format': 'message'} - assert len(req['input']['messages']) == 2 + assert req["model"] == model + assert req["parameters"] == {"result_format": "message"} + assert len(req["input"]["messages"]) == 2 def test_conversation_call_with_messages(self, mock_server: MockServer): messageManager = MessageManager(10) - messageManager.add(Message(role='system', content='你是达摩院的生活助手机器人。')) - messageManager.add(Message(role=Role.USER, content='推荐一个附近的公园')) + messageManager.add(Message(role="system", content="你是达摩院的生活助手机器人。")) + messageManager.add(Message(role=Role.USER, content="推荐一个附近的公园")) mock_server.responses.put( - 
json.dumps(TestConversationRequest.message_response_obj)) + json.dumps(TestConversationRequest.message_response_obj), + ) conv = Conversation() response = conv.call( model, messages=messageManager.get(), - result_format='message', + result_format="message", ) assert response.status_code == HTTPStatus.OK assert response.output.text is None - choices = TestConversationRequest.message_response_obj['output'][ - 'choices'] + choices = TestConversationRequest.message_response_obj["output"][ + "choices" + ] assert response.output.choices == choices req = mock_server.requests.get(block=True) - assert req['model'] == model - assert req['parameters'] == {'result_format': 'message'} - assert len(req['input']['messages']) == 2 + assert req["model"] == model + assert req["parameters"] == {"result_format": "message"} + assert len(req["input"]["messages"]) == 2 def test_not_qwen(self, mock_server: MockServer): - prompt = '介绍下杭州' + prompt = "介绍下杭州" mock_server.responses.put( - json.dumps(TestConversationRequest.text_response_obj)) - response = Generation.call(model=Generation.Models.dolly_12b_v2, - prompt=prompt) + json.dumps(TestConversationRequest.text_response_obj), + ) + response = Generation.call( + model=Generation.Models.dolly_12b_v2, + prompt=prompt, + ) assert response.status_code == HTTPStatus.OK - assert response.output.text == 'hello' + assert response.output.text == "hello" assert response.output.choices is None - assert response.output.finish_reason == 'stop' + assert response.output.finish_reason == "stop" req = mock_server.requests.get(block=True) - assert req['model'] == Generation.Models.dolly_12b_v2 - assert req['parameters'] == {} - assert req['input'] == {'prompt': prompt} + assert req["model"] == Generation.Models.dolly_12b_v2 + assert req["parameters"] == {} + assert req["input"] == {"prompt": prompt} diff --git a/tests/test_encryption.py b/tests/test_encryption.py index ec50417..ca0d6f4 100644 --- a/tests/test_encryption.py +++ b/tests/test_encryption.py @@ 
-1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. import json @@ -5,7 +6,6 @@ class TestEncryption: - @staticmethod def test_get_public_keys(): pub_keys = Encryption._get_public_keys() @@ -34,9 +34,11 @@ def test_encrypt_with_aes(): @staticmethod def test_encrypt_aes_key_with_rsa(): public_keys = Encryption._get_public_keys() - public_key = public_keys.get('public_key') + public_key = public_keys.get("public_key") aes_key = Encryption._generate_aes_secret_key() - cipher_aes_key = Encryption._encrypt_aes_key_with_rsa(aes_key, public_key) + cipher_aes_key = Encryption._encrypt_aes_key_with_rsa( + aes_key, + public_key, + ) print(f"\ncipher_aes_key: {cipher_aes_key}") - diff --git a/tests/test_http_api.py b/tests/test_http_api.py index d02a6c4..009f12e 100644 --- a/tests/test_http_api.py +++ b/tests/test_http_api.py @@ -1,94 +1,114 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. from http import HTTPStatus from dashscope import Generation -from tests.constants import (TEST_DISABLE_DATA_INSPECTION_REQUEST_ID, - TEST_ENABLE_DATA_INSPECTION_REQUEST_ID) +from tests.constants import ( + TEST_DISABLE_DATA_INSPECTION_REQUEST_ID, + TEST_ENABLE_DATA_INSPECTION_REQUEST_ID, +) from tests.http_task_request import HttpRequest from tests.mock_request_base import MockRequestBase def request_generator(): - yield 'hello' + yield "hello" class TestHttpRequest(MockRequestBase): def test_independent_model_sync_batch_request(self, http_server): resp = Generation.call( model=Generation.Models.qwen_turbo, - prompt='hello', + prompt="hello", max_tokens=1024, - api_protocol='http', + api_protocol="http", n=50, - headers={'request_id': TEST_DISABLE_DATA_INSPECTION_REQUEST_ID}) - assert resp.output.text == 'hello' - assert resp.output['text'] == 'hello' + headers={"request_id": TEST_DISABLE_DATA_INSPECTION_REQUEST_ID}, + ) + assert resp.output.text == "hello" + assert resp.output["text"] == "hello" - def 
test_disable_data_inspection(self, http_server, - mock_disable_data_inspection_env): + def test_disable_data_inspection( + self, + http_server, + mock_disable_data_inspection_env, + ): resp = Generation.call( model=Generation.Models.qwen_turbo, - prompt='hello', + prompt="hello", max_tokens=1024, - api_protocol='http', + api_protocol="http", n=50, - headers={'request_id': TEST_DISABLE_DATA_INSPECTION_REQUEST_ID}) - assert resp.output.text == 'hello' - assert resp.output['text'] == 'hello' + headers={"request_id": TEST_DISABLE_DATA_INSPECTION_REQUEST_ID}, + ) + assert resp.output.text == "hello" + assert resp.output["text"] == "hello" - def test_enable_data_inspection(self, http_server, - mock_enable_data_inspection_env): + def test_enable_data_inspection( + self, + http_server, + mock_enable_data_inspection_env, + ): resp = Generation.call( model=Generation.Models.qwen_turbo, - prompt='hello', + prompt="hello", max_tokens=1024, - api_protocol='http', + api_protocol="http", n=50, - headers={'request_id': TEST_ENABLE_DATA_INSPECTION_REQUEST_ID}) - assert resp.output.text == 'hello' - assert resp.output['text'] == 'hello' + headers={"request_id": TEST_ENABLE_DATA_INSPECTION_REQUEST_ID}, + ) + assert resp.output.text == "hello" + assert resp.output["text"] == "hello" def test_independent_model_sync_stream_request(self, http_server): - resp = Generation.call(model=Generation.Models.qwen_turbo, - prompt='hello', - max_tokens=1024, - stream=True, - api_protocol='http', - n=50) + resp = Generation.call( + model=Generation.Models.qwen_turbo, + prompt="hello", + max_tokens=1024, + stream=True, + api_protocol="http", + n=50, + ) for idx, rsp in enumerate(resp): assert rsp.output.text == str(idx) - print(rsp.output['text']) + print(rsp.output["text"]) def test_echo_request_with_file_object(self, http_server): - with open('tests/data/request_file.bin') as f: - resp = Generation.call(model=Generation.Models.qwen_turbo, - prompt=f, - max_tokens=1024, - api_protocol='http', - 
n=50) - assert resp.output.text[0] == 'hello' + with open("tests/data/request_file.bin") as f: + resp = Generation.call( + model=Generation.Models.qwen_turbo, + prompt=f, + max_tokens=1024, + api_protocol="http", + n=50, + ) + assert resp.output.text[0] == "hello" def test_echo_request_with_generator(self, http_server): - resp = Generation.call(model=Generation.Models.qwen_turbo, - prompt=request_generator(), - max_tokens=1024, - api_protocol='http', - n=50) - assert resp.output.text == 'hello' + resp = Generation.call( + model=Generation.Models.qwen_turbo, + prompt=request_generator(), + max_tokens=1024, + api_protocol="http", + n=50, + ) + assert resp.output.text == "hello" def test_send_receive_files(self, http_server): - bird_file = open('tests/data/bird.JPEG', 'rb') - dogs_file = open('tests/data/dogs.jpg', 'rb') - resp = HttpRequest.call(model=Generation.Models.qwen_turbo, - prompt='hello', - task_group='aigc', - task='image-generation', - function='generation', - api_protocol='http', - request_id='1111111111', - form={ - 'bird': bird_file, - 'dog': dogs_file - }) + bird_file = open("tests/data/bird.JPEG", "rb") + dogs_file = open("tests/data/dogs.jpg", "rb") + resp = HttpRequest.call( + model=Generation.Models.qwen_turbo, + prompt="hello", + task_group="aigc", + task="image-generation", + function="generation", + api_protocol="http", + request_id="1111111111", + form={ + "bird": bird_file, + "dog": dogs_file, + }, + ) assert resp.status_code == HTTPStatus.OK diff --git a/tests/test_http_deployments_api.py b/tests/test_http_deployments_api.py index f339bb1..43fd934 100644 --- a/tests/test_http_deployments_api.py +++ b/tests/test_http_deployments_api.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
from http import HTTPStatus @@ -9,24 +10,26 @@ class TestDeploymentRequest(MockRequestBase): def test_create_deployment_tune_job(self, http_server): - resp = Deployments.call(model='gpt', - suffix='1', - capacity=2, - headers={'X-Request-Id': '111111'}) + resp = Deployments.call( + model="gpt", + suffix="1", + capacity=2, + headers={"X-Request-Id": "111111"}, + ) assert resp.status_code == HTTPStatus.OK - assert resp.output['deployed_model'] == 'deploy123456' - assert resp.output['status'] == 'PENDING' + assert resp.output["deployed_model"] == "deploy123456" + assert resp.output["status"] == "PENDING" def test_list_deployment_job(self, http_server): rsp = Deployments.list() assert rsp.status_code == HTTPStatus.OK - assert len(rsp.output['deployments']) == 1 + assert len(rsp.output["deployments"]) == 1 def test_get_deployment_job(self, http_server): rsp = Deployments.get(TEST_JOB_ID) assert rsp.status_code == HTTPStatus.OK - assert rsp.output['deployed_model'] == TEST_JOB_ID - assert rsp.output['status'] == 'PENDING' + assert rsp.output["deployed_model"] == TEST_JOB_ID + assert rsp.output["status"] == "PENDING" def test_delete_deployment_job(self, http_server): rsp = Deployments.delete(TEST_JOB_ID) diff --git a/tests/test_http_files_api.py b/tests/test_http_files_api.py index 31b574c..c405a4b 100644 --- a/tests/test_http_files_api.py +++ b/tests/test_http_files_api.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
from http import HTTPStatus @@ -8,30 +9,32 @@ class TestFileRequest(MockRequestBase): def test_upload_files(self, http_server): - resp = Files.upload(file_path='tests/data/dogs.jpg', - purpose='fine-tune', - custom_file_name='gpt3_training.csv') + resp = Files.upload( + file_path="tests/data/dogs.jpg", + purpose="fine-tune", + custom_file_name="gpt3_training.csv", + ) print(resp) assert resp.status_code == HTTPStatus.OK - assert resp.output['uploaded_files'][0]['file_id'] == 'xxxx' + assert resp.output["uploaded_files"][0]["file_id"] == "xxxx" def test_list_files(self, http_server): resp = Files.list() assert resp.status_code == HTTPStatus.OK - assert len(resp.output['files']) == 3 + assert len(resp.output["files"]) == 3 def test_get_file(self, http_server): - resp = Files.get(file_id='111111') + resp = Files.get(file_id="111111") assert resp.status_code == HTTPStatus.OK - assert resp.output['file_id'] == '111111' + assert resp.output["file_id"] == "111111" def test_delete_file(self, http_server): - resp = Files.delete(file_id='111111') + resp = Files.delete(file_id="111111") assert resp.status_code == HTTPStatus.OK - resp = Files.delete(file_id='222222') # not exist + resp = Files.delete(file_id="222222") # not exist assert resp.status_code == HTTPStatus.NOT_FOUND - resp = Files.delete(file_id='333333') # no permission + resp = Files.delete(file_id="333333") # no permission assert resp.status_code == HTTPStatus.FORBIDDEN - resp = Files.delete(file_id='444444', api_key='api-key') # not exist + resp = Files.delete(file_id="444444", api_key="api-key") # not exist assert resp.status_code == HTTPStatus.UNAUTHORIZED diff --git a/tests/test_http_fine_tunes_api.py b/tests/test_http_fine_tunes_api.py index 03f5daf..1fcf26b 100644 --- a/tests/test_http_fine_tunes_api.py +++ b/tests/test_http_fine_tunes_api.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
import json @@ -18,7 +19,8 @@ class TestFineTuneRequest(MockServerBase): @classmethod def setup_class(cls): cls.case_data = json.load( - open('tests/data/fine_tune.json', 'r', encoding='utf-8')) + open('tests/data/fine_tune.json', 'r', encoding='utf-8'), + ) super().setup_class() def test_create_fine_tune_job(self, mock_server: MockServer): @@ -27,13 +29,16 @@ def test_create_fine_tune_job(self, mock_server: MockServer): model = 'gpt' training_file_ids = 'training_001' validation_file_ids = 'validation_001' - hyper_parameters = {'epochs': 10, - 'learning_rate': 0.001 - } - resp = FineTunes.call(model=model, - training_file_ids=training_file_ids, - validation_file_ids=validation_file_ids, - hyper_parameters=hyper_parameters) + hyper_parameters = { + 'epochs': 10, + 'learning_rate': 0.001, + } + resp = FineTunes.call( + model=model, + training_file_ids=training_file_ids, + validation_file_ids=validation_file_ids, + hyper_parameters=hyper_parameters, + ) req = mock_server.requests.get(block=True) assert req['path'] == '/api/v1/fine-tunes' assert req['body']['model'] == model @@ -52,12 +57,14 @@ def test_create_fine_tune_job_with_files(self, mock_server: MockServer): validation_file_ids = ['validation_001', 'validation_002'] hyper_parameters = { 'epochs': 10, - 'learning_rate': 0.001 - } - resp = FineTunes.call(model=model, - training_file_ids=training_file_ids, - validation_file_ids=validation_file_ids, - hyper_parameters=hyper_parameters) + 'learning_rate': 0.001, + } + resp = FineTunes.call( + model=model, + training_file_ids=training_file_ids, + validation_file_ids=validation_file_ids, + hyper_parameters=hyper_parameters, + ) req = mock_server.requests.get(block=True) assert req['path'] == '/api/v1/fine-tunes' assert req['body']['model'] == model @@ -73,8 +80,10 @@ def test_create_fine_tune_job_with_files(self, mock_server: MockServer): def test_list_fine_tune_job(self, mock_server: MockServer): response_body = self.case_data['list_response'] 
mock_server.responses.put(json.dumps(response_body)) - response = FineTunes.list(page_no=10, - page_size=101) + response = FineTunes.list( + page_no=10, + page_size=101, + ) req = mock_server.requests.get(block=True) assert req['path'] == '/api/v1/fine-tunes?page_no=10&page_size=101' assert len(response.output.jobs) == 2 diff --git a/tests/test_http_models_api.py b/tests/test_http_models_api.py index b4c9861..f175880 100644 --- a/tests/test_http_models_api.py +++ b/tests/test_http_models_api.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. from http import HTTPStatus @@ -11,9 +12,9 @@ class TestModelRequest(MockRequestBase): def test_list_models(self, http_server): rsp = Models.list() assert rsp.status_code == HTTPStatus.OK - assert len(rsp.output['models']) == 2 + assert len(rsp.output["models"]) == 2 def test_get_model(self, http_server): rsp = Models.get(TEST_JOB_ID) assert rsp.status_code == HTTPStatus.OK - assert rsp.output['model_id'] == TEST_JOB_ID + assert rsp.output["model_id"] == TEST_JOB_ID diff --git a/tests/test_image_synthesis.py b/tests/test_image_synthesis.py index f024fbf..aa7950b 100644 --- a/tests/test_image_synthesis.py +++ b/tests/test_image_synthesis.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
from http import HTTPStatus @@ -13,51 +14,57 @@ def setup_class(cls): super().setup_class() def test_create_task(self): - rsp = ImageSynthesis.call(model=ImageSynthesis.Models.wanx_2_1_imageedit, - function="description_edit_with_mask", - prompt="帮我编辑图片。把所选区域变成黑色", - mask_image_url="https://static.dingtalk.com/media/lQLPD2ob9dfKPBvNBADNAkCwOcPjjaFVcEcHqrO8n1BLAA_576_1024.png_620x10000q90.png", - base_image_url="https://static.dingtalk.com/media/lQLPD2jl0mg85BvNBADNAkCwNJvjWJXBVMwHqrO0OvZlAA_576_1024.png_620x10000q90.png", - n=1) + rsp = ImageSynthesis.call( + model=ImageSynthesis.Models.wanx_2_1_imageedit, + function="description_edit_with_mask", + prompt="帮我编辑图片。把所选区域变成黑色", + mask_image_url="https://static.dingtalk.com/media/lQLPD2ob9dfKPBvNBADNAkCwOcPjjaFVcEcHqrO8n1BLAA_576_1024.png_620x10000q90.png", + base_image_url="https://static.dingtalk.com/media/lQLPD2jl0mg85BvNBADNAkCwNJvjWJXBVMwHqrO0OvZlAA_576_1024.png_620x10000q90.png", + n=1, + ) assert rsp.status_code == HTTPStatus.OK - assert rsp.output['task_status'] == 'SUCCEEDED' - assert len(rsp.output['results']) == 1 + assert rsp.output["task_status"] == "SUCCEEDED" + assert len(rsp.output["results"]) == 1 def test_fetch_status(self): - rsp = ImageSynthesis.call(model=ImageSynthesis.Models.wanx_2_1_imageedit, - function="description_edit_with_mask", - prompt="帮我编辑图片。把所选区域变成黑色", - mask_image_url="https://static.dingtalk.com/media/lQLPD2ob9dfKPBvNBADNAkCwOcPjjaFVcEcHqrO8n1BLAA_576_1024.png_620x10000q90.png", - base_image_url="https://static.dingtalk.com/media/lQLPD2jl0mg85BvNBADNAkCwNJvjWJXBVMwHqrO0OvZlAA_576_1024.png_620x10000q90.png", - n=1) + rsp = ImageSynthesis.call( + model=ImageSynthesis.Models.wanx_2_1_imageedit, + function="description_edit_with_mask", + prompt="帮我编辑图片。把所选区域变成黑色", + mask_image_url="https://static.dingtalk.com/media/lQLPD2ob9dfKPBvNBADNAkCwOcPjjaFVcEcHqrO8n1BLAA_576_1024.png_620x10000q90.png", + 
base_image_url="https://static.dingtalk.com/media/lQLPD2jl0mg85BvNBADNAkCwNJvjWJXBVMwHqrO0OvZlAA_576_1024.png_620x10000q90.png", + n=1, + ) assert rsp.status_code == HTTPStatus.OK rsp = ImageSynthesis.fetch(rsp) assert rsp.status_code == HTTPStatus.OK def test_wait(self): - rsp = ImageSynthesis.async_call(model=ImageSynthesis.Models.wanx_2_1_imageedit, - function="description_edit_with_mask", - prompt="帮我编辑图片。把所选区域变成黑色", - mask_image_url="https://static.dingtalk.com/media/lQLPD2ob9dfKPBvNBADNAkCwOcPjjaFVcEcHqrO8n1BLAA_576_1024.png_620x10000q90.png", - base_image_url="https://static.dingtalk.com/media/lQLPD2jl0mg85BvNBADNAkCwNJvjWJXBVMwHqrO0OvZlAA_576_1024.png_620x10000q90.png", - n=1) + rsp = ImageSynthesis.async_call( + model=ImageSynthesis.Models.wanx_2_1_imageedit, + function="description_edit_with_mask", + prompt="帮我编辑图片。把所选区域变成黑色", + mask_image_url="https://static.dingtalk.com/media/lQLPD2ob9dfKPBvNBADNAkCwOcPjjaFVcEcHqrO8n1BLAA_576_1024.png_620x10000q90.png", + base_image_url="https://static.dingtalk.com/media/lQLPD2jl0mg85BvNBADNAkCwNJvjWJXBVMwHqrO0OvZlAA_576_1024.png_620x10000q90.png", + n=1, + ) assert rsp.status_code == HTTPStatus.OK rsp = ImageSynthesis.wait(rsp) assert rsp.status_code == HTTPStatus.OK - assert rsp.output.task_id != '' # verify access by properties. + assert rsp.output.task_id != "" # verify access by properties. 
assert rsp.output.task_status == TaskStatus.SUCCEEDED assert len(rsp.output.results) == 1 - assert rsp.output.results[0].url != '' + assert rsp.output.results[0].url != "" - assert rsp.output['task_id'] != '' - assert rsp.output['task_status'] == TaskStatus.SUCCEEDED - assert len(rsp.output['results']) == 1 - assert rsp.output['results'][0]['url'] != '' + assert rsp.output["task_id"] != "" + assert rsp.output["task_status"] == TaskStatus.SUCCEEDED + assert len(rsp.output["results"]) == 1 + assert rsp.output["results"][0]["url"] != "" def test_list_cancel_task(self): - rsp = ImageSynthesis.list(status='CANCELED') + rsp = ImageSynthesis.list(status="CANCELED") assert rsp.status_code == HTTPStatus.OK def test_list_all(self): diff --git a/tests/test_messages.py b/tests/test_messages.py index 514b8e6..6b1d8be 100644 --- a/tests/test_messages.py +++ b/tests/test_messages.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. import json @@ -10,112 +11,131 @@ class TestMessages(MockServerBase): - TEST_MODEL_NAME = 'test_model' - ASSISTANT_ID = 'asst_42bff274-6d44-45b8-90b1-11dd14534499' + TEST_MODEL_NAME = "test_model" + ASSISTANT_ID = "asst_42bff274-6d44-45b8-90b1-11dd14534499" case_data = None @classmethod def setup_class(cls): cls.case_data = json.load( - open('tests/data/messages.json', 'r', encoding='utf-8')) + open("tests/data/messages.json", "r", encoding="utf-8"), + ) super().setup_class() def test_create(self, mock_server: MockServer): - request_body = self.case_data['create_message_request'] - response_body = self.case_data['create_message_response'] + request_body = self.case_data["create_message_request"] + response_body = self.case_data["create_message_response"] mock_server.responses.put(json.dumps(response_body)) response = Messages.create(**request_body) req = mock_server.requests.get(block=True) - assert req['role'] == 'user' - assert req['content'] == response.content[0].text.value - assert response.thread_id == 
response_body['thread_id'] + assert req["role"] == "user" + assert req["content"] == response.content[0].text.value + assert response.thread_id == response_body["thread_id"] assert len(response.content) == 1 def test_update(self, mock_server: MockServer): - response_body = self.case_data['create_message_response'] + response_body = self.case_data["create_message_response"] mock_server.responses.put(json.dumps(response_body)) thread_id = str(uuid.uuid4()) message_id = str(uuid.uuid4()) - metadata = {'key': 'value'} - response = Messages.update(message_id, - thread_id=thread_id, - metadata=metadata, - workspace='111') + metadata = {"key": "value"} + response = Messages.update( + message_id, + thread_id=thread_id, + metadata=metadata, + workspace="111", + ) req = mock_server.requests.get(block=True) - assert req['body']['metadata'] == metadata - assert req[ - 'path'] == f'/api/v1/threads/{thread_id}/messages/{message_id}' - assert req['headers']['X-DashScope-WorkSpace'] == '111' - assert response.thread_id == response_body['thread_id'] + assert req["body"]["metadata"] == metadata + assert ( + req["path"] == f"/api/v1/threads/{thread_id}/messages/{message_id}" + ) + assert req["headers"]["X-DashScope-WorkSpace"] == "111" + assert response.thread_id == response_body["thread_id"] assert len(response.content) == 1 def test_retrieve(self, mock_server: MockServer): - response_obj = self.case_data['create_message_response'] + response_obj = self.case_data["create_message_response"] response_str = json.dumps(response_obj) mock_server.responses.put(response_str) - thread_id = 'tid' - message_id = 'mid' + thread_id = "tid" + message_id = "mid" response = Messages.retrieve(message_id, thread_id=thread_id) # get assistant id we send. 
path = mock_server.requests.get(block=True) - assert path == f'/api/v1/threads/{thread_id}/messages/{message_id}' - assert response.thread_id == response_obj['thread_id'] + assert path == f"/api/v1/threads/{thread_id}/messages/{message_id}" + assert response.thread_id == response_obj["thread_id"] assert len(response.content) == 1 def test_list(self, mock_server: MockServer): - response_obj = self.case_data['list_message_response'] + response_obj = self.case_data["list_message_response"] mock_server.responses.put(json.dumps(response_obj)) - thread_id = 'test_thread_id' - response = Messages.list(thread_id, - limit=10, - order='inc', - after='after', - before='before') + thread_id = "test_thread_id" + response = Messages.list( + thread_id, + limit=10, + order="inc", + after="after", + before="before", + ) # get assistant id we send. req = mock_server.requests.get(block=True) - assert req == f'/api/v1/threads/{thread_id}/messages?limit=10&order=inc&after=after&before=before' + assert ( + req + == f"/api/v1/threads/{thread_id}/messages?limit=10&order=inc&after=after&before=before" + ) assert len(response.data) == 2 - assert response.data[0].id == 'msg_1' - assert response.data[1].id == 'msg_0' + assert response.data[0].id == "msg_1" + assert response.data[1].id == "msg_0" def test_list_message_files(self, mock_server: MockServer): - response_obj = self.case_data['list_message_files_response'] + response_obj = self.case_data["list_message_files_response"] mock_server.responses.put(json.dumps(response_obj)) - thread_id = 'test_thread_id' - message_id = 'test_message_id' - response = Files.list(message_id, - thread_id=thread_id, - limit=10, - order='inc', - after='after', - before='before') + thread_id = "test_thread_id" + message_id = "test_message_id" + response = Files.list( + message_id, + thread_id=thread_id, + limit=10, + order="inc", + after="after", + before="before", + ) # get assistant id we send. 
req = mock_server.requests.get(block=True) - assert req == f'/api/v1/threads/{thread_id}/messages/{message_id}/files?limit=10&order=inc&after=after&before=before' # noqa E501 + assert ( + req + == f"/api/v1/threads/{thread_id}/messages/{message_id}/files?limit=10&order=inc&after=after&before=before" + ) # noqa E501 assert len(response.data) == 2 - assert response.data[0].id == 'file-1' - assert response.data[1].id == 'file-2' + assert response.data[0].id == "file-1" + assert response.data[1].id == "file-2" def test_retrieve_message_file(self, mock_server: MockServer): file_id = str(uuid.uuid4()) message_id = str(uuid.uuid4()) response_obj = { - 'id': file_id, - 'object': 'thread.message.file', - 'created_at': 11111111, - 'message_id': message_id + "id": file_id, + "object": "thread.message.file", + "created_at": 11111111, + "message_id": message_id, } mock_server.responses.put(json.dumps(response_obj)) - thread_id = 'test_thread_id' - response = Files.retrieve(file_id, - thread_id=thread_id, - message_id=message_id, - limit=10, - order='inc', - after='after', - before='before') + thread_id = "test_thread_id" + response = Files.retrieve( + file_id, + thread_id=thread_id, + message_id=message_id, + limit=10, + order="inc", + after="after", + before="before", + ) # get assistant id we send. req = mock_server.requests.get(block=True) - assert req == f'/api/v1/threads/{thread_id}/messages/{message_id}/files/{file_id}' + assert ( + req + == f"/api/v1/threads/{thread_id}/messages/{message_id}/files/{file_id}" + ) assert response.id == file_id assert response.message_id == message_id diff --git a/tests/test_multimodal_dialog.py b/tests/test_multimodal_dialog.py index 166e131..e919c8c 100644 --- a/tests/test_multimodal_dialog.py +++ b/tests/test_multimodal_dialog.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
import sys import pytest @@ -5,12 +6,21 @@ import time from dashscope.common.logging import logger from dashscope.multimodal.dialog_state import DialogState -from dashscope.multimodal.multimodal_dialog import MultiModalDialog, MultiModalCallback -from dashscope.multimodal.multimodal_request_params import Upstream, Downstream, ClientInfo, RequestParameters, Device, \ - RequestToRespondParameters +from dashscope.multimodal.multimodal_dialog import ( + MultiModalDialog, + MultiModalCallback, +) +from dashscope.multimodal.multimodal_request_params import ( + Upstream, + Downstream, + ClientInfo, + RequestParameters, + Device, + RequestToRespondParameters, +) from tests.base_test import BaseTestEnvironment -logger = logging.getLogger('dashscope') +logger = logging.getLogger("dashscope") logger.setLevel(logging.DEBUG) # create console handler and set level to debug console_handler = logging.StreamHandler() @@ -76,40 +86,63 @@ def on_request_accepted(self): return def on_close(self, close_status_code, close_msg): - logger.info("close with status code: %d, msg: %s" % (close_status_code, close_msg)) + logger.info( + "close with status code: %d, msg: %s" + % (close_status_code, close_msg), + ) class TestMultiModalDialog(BaseTestEnvironment): @classmethod def setup_class(cls): super().setup_class() - cls.model = 'multimodal-dialog' - cls.voice = 'longxiaochun_v2' + cls.model = "multimodal-dialog" + cls.voice = "longxiaochun_v2" @pytest.mark.skip def test_multimodal_dialog_one_turn(self): # 对话状态 Listening->Thinking->Responding->Listening - up_stream = Upstream(type="AudioOnly", mode="push2talk", audio_format="pcm") + up_stream = Upstream( + type="AudioOnly", + mode="push2talk", + audio_format="pcm", + ) down_stream = Downstream(voice=self.voice, sample_rate=16000) - client_info = ClientInfo(user_id="aabb", device=Device(uuid="1234567890")) - request_params = RequestParameters(upstream=up_stream, downstream=down_stream, - client_info=client_info) + client_info = ClientInfo( + 
user_id="aabb", + device=Device(uuid="1234567890"), + ) + request_params = RequestParameters( + upstream=up_stream, + downstream=down_stream, + client_info=client_info, + ) self.callback = TestCallback() - self.conversation = MultiModalDialog(app_id="", - workspace_id="llm-xxxx", - url="wss://poc-dashscope.aliyuncs.com/api-ws/v1/inference", - request_params=request_params, - multimodal_callback=self.callback, - model=self.model) + self.conversation = MultiModalDialog( + app_id="", + workspace_id="llm-xxxx", + url="wss://poc-dashscope.aliyuncs.com/api-ws/v1/inference", + request_params=request_params, + multimodal_callback=self.callback, + model=self.model, + ) self.conversation.start("") # 首轮进入Listening - while DialogState.LISTENING is not self.conversation.get_dialog_state(): + while ( + DialogState.LISTENING is not self.conversation.get_dialog_state() + ): time.sleep(0.1) - self.conversation.request_to_respond("prompt", "今天天气不错", parameters=None) + self.conversation.request_to_respond( + "prompt", + "今天天气不错", + parameters=None, + ) # 等待第二轮Listening - while DialogState.LISTENING is not self.conversation.get_dialog_state(): + while ( + DialogState.LISTENING is not self.conversation.get_dialog_state() + ): time.sleep(0.1) self.conversation.stop() diff --git a/tests/test_rerank.py b/tests/test_rerank.py index 25e553a..bf0890e 100644 --- a/tests/test_rerank.py +++ b/tests/test_rerank.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
import json @@ -11,28 +12,28 @@ class TestReRank(MockServerBase): def test_call(self, mock_server: MockServer): response_body = { - 'output': { - 'results': [ + "output": { + "results": [ { - 'index': 1, - 'relevance_score': 0.987654, - 'document': { # 如果return_documents=true - 'text': '哈尔滨是中国黑龙江省的省会,位于中国东北' - } + "index": 1, + "relevance_score": 0.987654, + "document": { # 如果return_documents=true + "text": "哈尔滨是中国黑龙江省的省会,位于中国东北", + }, }, { - 'index': 0, - 'relevance_score': 0.876543, - 'document': { # 如果return_documents=true - 'text': '黑龙江离俄罗斯很近' - } - } - ] + "index": 0, + "relevance_score": 0.876543, + "document": { # 如果return_documents=true + "text": "黑龙江离俄罗斯很近", + }, + }, + ], }, - 'usage': { - 'input_tokens': 1279 + "usage": { + "input_tokens": 1279, }, - 'request_id': 'b042e72d-7994-97dd-b3d2-7ee7e0140525' + "request_id": "b042e72d-7994-97dd-b3d2-7ee7e0140525", } mock_server.responses.put(json.dumps(response_body)) model = str(uuid.uuid4()) @@ -41,21 +42,23 @@ def test_call(self, mock_server: MockServer): str(uuid.uuid4()), str(uuid.uuid4()), str(uuid.uuid4()), - str(uuid.uuid4()) + str(uuid.uuid4()), ] - response = TextReRank.call(model=model, - query=query, - documents=documents, - return_documents=False, - top_n=10) + response = TextReRank.call( + model=model, + query=query, + documents=documents, + return_documents=False, + top_n=10, + ) req = mock_server.requests.get(block=True) - assert req['path'] == '/api/v1/services/rerank/text-rerank/text-rerank' - assert req['body']['parameters'] == { - 'return_documents': False, - 'top_n': 10 + assert req["path"] == "/api/v1/services/rerank/text-rerank/text-rerank" + assert req["body"]["parameters"] == { + "return_documents": False, + "top_n": 10, } - assert req['body']['input'] == {'query': query, 'documents': documents} - assert response.usage['input_tokens'] == 1279 - assert len(response.output['results']) == 2 - assert response.output['results'][0]['index'] == 1 - assert 
response.output['results'][1]['document']['text'] == '黑龙江离俄罗斯很近' + assert req["body"]["input"] == {"query": query, "documents": documents} + assert response.usage["input_tokens"] == 1279 + assert len(response.output["results"]) == 2 + assert response.output["results"][0]["index"] == 1 + assert response.output["results"][1]["document"]["text"] == "黑龙江离俄罗斯很近" diff --git a/tests/test_runs.py b/tests/test_runs.py index 80f0173..845edf4 100644 --- a/tests/test_runs.py +++ b/tests/test_runs.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. import json @@ -9,149 +10,167 @@ class TestRuns(MockServerBase): - TEST_MODEL_NAME = 'test_model' - ASSISTANT_ID = 'asst_42bff274-6d44-45b8-90b1-11dd14534499' + TEST_MODEL_NAME = "test_model" + ASSISTANT_ID = "asst_42bff274-6d44-45b8-90b1-11dd14534499" case_data = None @classmethod def setup_class(cls): cls.case_data = json.load( - open('tests/data/runs.json', 'r', encoding='utf-8')) + open("tests/data/runs.json", "r", encoding="utf-8"), + ) super().setup_class() def test_create_simple(self, mock_server: MockServer): - response_body = self.case_data['create_run_response'] + response_body = self.case_data["create_run_response"] mock_server.responses.put(json.dumps(response_body)) thread_id = str(uuid.uuid4()) assistant_id = str(uuid.uuid4()) response = Runs.create(thread_id, assistant_id=assistant_id) req = mock_server.requests.get(block=True) - assert req['assistant_id'] == assistant_id - assert response.thread_id == response_body['thread_id'] - assert response.metadata == {'key': 'value'} + assert req["assistant_id"] == assistant_id + assert response.thread_id == response_body["thread_id"] + assert response.metadata == {"key": "value"} def test_create_complicated(self, mock_server: MockServer): - response_body = self.case_data['create_run_response'] + response_body = self.case_data["create_run_response"] mock_server.responses.put(json.dumps(response_body)) thread_id = str(uuid.uuid4()) 
assistant_id = str(uuid.uuid4()) model_name = str(uuid.uuid4()) - instructions = 'Your a tool.' - additional_instructions = 'additional_instructions' - tools = [{ - 'type': 'code_interpreter' - }, { - 'type': 'search' - }, { - 'type': 'function', - 'function': { - 'name': 'big_add', - 'description': 'Add to number', - 'parameters': { - 'type': 'object', - 'properties': { - 'left': { - 'type': 'integer', - 'description': 'The left operator' + instructions = "Your a tool." + additional_instructions = "additional_instructions" + tools = [ + { + "type": "code_interpreter", + }, + { + "type": "search", + }, + { + "type": "function", + "function": { + "name": "big_add", + "description": "Add to number", + "parameters": { + "type": "object", + "properties": { + "left": { + "type": "integer", + "description": "The left operator", + }, + "right": { + "type": "integer", + "description": "The right operator.", + }, }, - 'right': { - 'type': 'integer', - 'description': 'The right operator.' - } + "required": ["left", "right"], }, - 'required': ['left', 'right'] - } - } - }] - metadata = {'key': 'meta'} - response = Runs.create(thread_id, - assistant_id=assistant_id, - model=model_name, - instructions=instructions, - additional_instructions=additional_instructions, - tools=tools, - metadata=metadata) + }, + }, + ] + metadata = {"key": "meta"} + response = Runs.create( + thread_id, + assistant_id=assistant_id, + model=model_name, + instructions=instructions, + additional_instructions=additional_instructions, + tools=tools, + metadata=metadata, + ) req = mock_server.requests.get(block=True) - assert req['assistant_id'] == assistant_id - assert req['model'] == model_name - assert req['instructions'] == instructions - assert req['additional_instructions'] == additional_instructions - assert req['metadata'] == metadata - assert req['tools'] == tools - assert response.thread_id == response_body['thread_id'] + assert req["assistant_id"] == assistant_id + assert req["model"] == 
model_name + assert req["instructions"] == instructions + assert req["additional_instructions"] == additional_instructions + assert req["metadata"] == metadata + assert req["tools"] == tools + assert response.thread_id == response_body["thread_id"] assert len(response.tools) == 3 - assert response.tools[0].type == 'code_interpreter' + assert response.tools[0].type == "code_interpreter" def test_retrieve(self, mock_server: MockServer): - response_obj = self.case_data['create_run_response'] + response_obj = self.case_data["create_run_response"] response_str = json.dumps(response_obj) mock_server.responses.put(response_str) - thread_id = 'tid' + thread_id = "tid" run_id = str(uuid.uuid4()) response = Runs.retrieve(run_id, thread_id=thread_id) # get assistant id we send. path = mock_server.requests.get(block=True) - assert path == f'/api/v1/threads/{thread_id}/runs/{run_id}' + assert path == f"/api/v1/threads/{thread_id}/runs/{run_id}" assert len(response.tools) == 3 - assert response.tools[0].type == 'code_interpreter' + assert response.tools[0].type == "code_interpreter" def test_list(self, mock_server: MockServer): - response_obj = self.case_data['list_run_response'] + response_obj = self.case_data["list_run_response"] mock_server.responses.put(json.dumps(response_obj)) - thread_id = 'test_thread_id' - response = Runs.list(thread_id, - limit=10, - order='inc', - after='after', - before='before') + thread_id = "test_thread_id" + response = Runs.list( + thread_id, + limit=10, + order="inc", + after="after", + before="before", + ) # get assistant id we send. 
req = mock_server.requests.get(block=True) - assert req == f'/api/v1/threads/{thread_id}/runs?limit=10&order=inc&after=after&before=before' + assert ( + req + == f"/api/v1/threads/{thread_id}/runs?limit=10&order=inc&after=after&before=before" + ) assert len(response.data) == 1 - assert response.data[0].id == '1' - assert response.data[0].tools[2].type == 'function' + assert response.data[0].id == "1" + assert response.data[0].tools[2].type == "function" def test_create_thread_and_run(self, mock_server: MockServer): - response_body = self.case_data['create_run_response'] + response_body = self.case_data["create_run_response"] mock_server.responses.put(json.dumps(response_body)) assistant_id = str(uuid.uuid4()) model_name = str(uuid.uuid4()) - instructions = 'Your a tool.' - additional_instructions = 'additional_instructions' - tools = [{ - 'type': 'code_interpreter' - }, { - 'type': 'search' - }, { - 'type': 'function', - 'function': { - 'name': 'big_add', - 'description': 'Add to number', - 'parameters': { - 'type': 'object', - 'properties': { - 'left': { - 'type': 'integer', - 'description': 'The left operator' + instructions = "Your a tool." + additional_instructions = "additional_instructions" + tools = [ + { + "type": "code_interpreter", + }, + { + "type": "search", + }, + { + "type": "function", + "function": { + "name": "big_add", + "description": "Add to number", + "parameters": { + "type": "object", + "properties": { + "left": { + "type": "integer", + "description": "The left operator", + }, + "right": { + "type": "integer", + "description": "The right operator.", + }, }, - 'right': { - 'type': 'integer', - 'description': 'The right operator.' 
- } + "required": ["left", "right"], }, - 'required': ['left', 'right'] - } - } - }] - metadata = {'key': 'meta'} + }, + }, + ] + metadata = {"key": "meta"} thread = { - 'messages': [{ - 'role': 'user', - 'content': 'Test content' - }], - 'metadata': { - 'key': 'meta' - } + "messages": [ + { + "role": "user", + "content": "Test content", + }, + ], + "metadata": { + "key": "meta", + }, } # process by handle_update_object_with_post, response = Runs.create_thread_and_run( @@ -161,121 +180,148 @@ def test_create_thread_and_run(self, mock_server: MockServer): instructions=instructions, additional_instructions=additional_instructions, tools=tools, - metadata=metadata) + metadata=metadata, + ) req = mock_server.requests.get(block=True) - assert req['assistant_id'] == assistant_id - assert req['model'] == model_name - assert req['instructions'] == instructions - assert req['additional_instructions'] == additional_instructions - assert req['metadata'] == metadata - assert req['tools'] == tools - assert req['thread'] == thread - assert response.thread_id == response_body['thread_id'] + assert req["assistant_id"] == assistant_id + assert req["model"] == model_name + assert req["instructions"] == instructions + assert req["additional_instructions"] == additional_instructions + assert req["metadata"] == metadata + assert req["tools"] == tools + assert req["thread"] == thread + assert response.thread_id == response_body["thread_id"] assert len(response.tools) == 3 - assert response.tools[0].type == 'code_interpreter' + assert response.tools[0].type == "code_interpreter" def test_submit_tool_outputs(self, mock_server: MockServer): - response_body = self.case_data['submit_function_call_result'] + response_body = self.case_data["submit_function_call_result"] mock_server.responses.put(json.dumps(response_body)) thread_id = str(uuid.uuid4()) run_id = str(uuid.uuid4()) - tool_outputs = [{ - 'output': '789076524', - 'tool_call_id': 'call_DqGuSZ1NtWimgQcj8tGph6So' - }] - response = 
Runs.submit_tool_outputs(thread_id=thread_id, - run_id=run_id, - tool_outputs=tool_outputs) + tool_outputs = [ + { + "output": "789076524", + "tool_call_id": "call_DqGuSZ1NtWimgQcj8tGph6So", + }, + ] + response = Runs.submit_tool_outputs( + thread_id=thread_id, + run_id=run_id, + tool_outputs=tool_outputs, + ) req = mock_server.requests.get(block=True) - assert req['tool_outputs'] == tool_outputs - assert response.thread_id == response_body['thread_id'] + assert req["tool_outputs"] == tool_outputs + assert response.thread_id == response_body["thread_id"] assert len(response.tools) == 3 - assert response.tools[0].type == 'code_interpreter' + assert response.tools[0].type == "code_interpreter" def test_run_required_function_call(self, mock_server: MockServer): - response_obj = self.case_data['required_action_function_call_response'] + response_obj = self.case_data["required_action_function_call_response"] mock_server.responses.put(json.dumps(response_obj)) thread_id = str(uuid.uuid4()) assistant_id = str(uuid.uuid4()) response = Runs.create(thread_id, assistant_id=assistant_id) # how to dump response to json. 
- s = json.dumps(response, - default=lambda o: o.__dict__, - sort_keys=True, - indent=4) + s = json.dumps( + response, + default=lambda o: o.__dict__, + sort_keys=True, + indent=4, + ) print(s) req = mock_server.requests.get(block=True) - assert req['assistant_id'] == assistant_id - assert response.required_action.submit_tool_outputs.tool_calls[ - 0].id == 'call_1' + assert req["assistant_id"] == assistant_id + assert ( + response.required_action.submit_tool_outputs.tool_calls[0].id + == "call_1" + ) def test_list_run_steps(self, mock_server: MockServer): - response_obj = self.case_data['list_run_steps_response'] + response_obj = self.case_data["list_run_steps_response"] mock_server.responses.put(json.dumps(response_obj)) - thread_id = 'test_thread_id' + thread_id = "test_thread_id" run_id = str(uuid.uuid4()) - response = Steps.list(run_id, - thread_id=thread_id, - limit=10, - order='inc', - after='after', - before='before') + response = Steps.list( + run_id, + thread_id=thread_id, + limit=10, + order="inc", + after="after", + before="before", + ) # get assistant id we send. 
req = mock_server.requests.get(block=True) - assert req == f'/api/v1/threads/{thread_id}/runs/{run_id}/steps?limit=10&order=inc&after=after&before=before' + assert ( + req + == f"/api/v1/threads/{thread_id}/runs/{run_id}/steps?limit=10&order=inc&after=after&before=before" + ) assert len(response.data) == 2 - assert response.data[0].id == 'step_1' - assert response.data[0].step_details.type == 'message_creation' - assert response.data[ - 0].step_details.message_creation.message_id == 'msg_1' + assert response.data[0].id == "step_1" + assert response.data[0].step_details.type == "message_creation" + assert ( + response.data[0].step_details.message_creation.message_id + == "msg_1" + ) assert response.data[0].usage.completion_tokens == 25 assert response.data[0].usage.prompt_tokens == 809 assert response.data[0].usage.total_tokens == 834 - assert response.data[1].id == 'step_2' - assert response.data[1].step_details.type == 'tool_calls' - assert response.data[1].step_details.tool_calls[0].type == 'function' - assert response.data[1].step_details.tool_calls[0].id == 'call_1' - assert response.data[1].step_details.tool_calls[ - 0].function.arguments == "{\"left\":87787,\"right\":788988737}" - assert response.data[1].step_details.tool_calls[ - 0].function.output == '789076524' - assert response.data[1].step_details.tool_calls[ - 0].function.name == 'big_add' + assert response.data[1].id == "step_2" + assert response.data[1].step_details.type == "tool_calls" + assert response.data[1].step_details.tool_calls[0].type == "function" + assert response.data[1].step_details.tool_calls[0].id == "call_1" + assert ( + response.data[1].step_details.tool_calls[0].function.arguments + == '{"left":87787,"right":788988737}' + ) + assert ( + response.data[1].step_details.tool_calls[0].function.output + == "789076524" + ) + assert ( + response.data[1].step_details.tool_calls[0].function.name + == "big_add" + ) def test_retrieve_run_steps(self, mock_server: MockServer): - response_obj = 
self.case_data['retrieve_run_step'] + response_obj = self.case_data["retrieve_run_step"] mock_server.responses.put(json.dumps(response_obj)) thread_id = str(uuid.uuid4()) run_id = str(uuid.uuid4()) step_id = str(uuid.uuid4()) - response = Steps.retrieve(step_id, - thread_id=thread_id, - run_id=run_id, - limit=10, - order='inc', - after='after', - before='before') + response = Steps.retrieve( + step_id, + thread_id=thread_id, + run_id=run_id, + limit=10, + order="inc", + after="after", + before="before", + ) # get assistant id we send. req = mock_server.requests.get(block=True) - assert req == f'/api/v1/threads/{thread_id}/runs/{run_id}/steps/{step_id}' - assert response.id == 'step_1' + assert ( + req == f"/api/v1/threads/{thread_id}/runs/{run_id}/steps/{step_id}" + ) + assert response.id == "step_1" - assert response.step_details.type == 'tool_calls' - assert response.step_details.tool_calls[0].id == 'call_1' - assert response.step_details.tool_calls[0].function.name == 'big_add' - assert response.step_details.tool_calls[ - 0].function.output == '789076524' + assert response.step_details.type == "tool_calls" + assert response.step_details.tool_calls[0].id == "call_1" + assert response.step_details.tool_calls[0].function.name == "big_add" + assert ( + response.step_details.tool_calls[0].function.output == "789076524" + ) assert response.usage.total_tokens == 798 assert response.usage.prompt_tokens == 776 assert response.usage.completion_tokens == 22 def test_cancel(self, mock_server: MockServer): - response_obj = self.case_data['retrieve_run_step'] + response_obj = self.case_data["retrieve_run_step"] mock_server.responses.put(json.dumps(response_obj)) thread_id = str(uuid.uuid4()) run_id = str(uuid.uuid4()) response = Runs.cancel(run_id, thread_id=thread_id) # get assistant id we send. 
req = mock_server.requests.get(block=True) - assert req == f'/api/v1/threads/{thread_id}/runs/{run_id}/cancel' - assert response.id == 'step_1' + assert req == f"/api/v1/threads/{thread_id}/runs/{run_id}/cancel" + assert response.id == "step_1" diff --git a/tests/test_sketch_image_synthesis.py b/tests/test_sketch_image_synthesis.py index 0f4fee2..3d2c814 100644 --- a/tests/test_sketch_image_synthesis.py +++ b/tests/test_sketch_image_synthesis.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. import json @@ -10,57 +11,61 @@ class TestSketchImageSynthesis(MockServerBase): text_response_obj = { - 'status_code': 200, - 'request_id': 'effd2cb1-1a8c-9f18-9a49-a396f673bd40', - 'code': '', - 'message': '', - 'output': { - 'task_id': 'hello', - 'task_status': 'SUCCEEDED', - 'results': [{ - 'url': 'url1' - }] + "status_code": 200, + "request_id": "effd2cb1-1a8c-9f18-9a49-a396f673bd40", + "code": "", + "message": "", + "output": { + "task_id": "hello", + "task_status": "SUCCEEDED", + "results": [ + { + "url": "url1", + }, + ], + }, + "usage": { + "image_count": 4, }, - 'usage': { - 'image_count': 4, - } } def test_with_all_parameters(self, mock_server: MockServer): response_str = json.dumps(TestSketchImageSynthesis.text_response_obj) mock_server.responses.put(response_str) - prompt = 'hello' + prompt = "hello" response = ImageSynthesis.async_call( model=ImageSynthesis.Models.wanx_sketch_to_image_v1, prompt=prompt, - sketch_image_url='http://sketch_url', + sketch_image_url="http://sketch_url", n=4, - size='1024*1024', + size="1024*1024", sketch_weight=8, - realisticness=9) + realisticness=9, + ) req = mock_server.requests.get(block=True) expect_req_str = '{"model": "wanx-sketch-to-image-v1", "parameters": {"n": 4, "size": "1024*1024", "sketch_weight": 8, "realisticness": 9}, "input": {"prompt": "hello", "sketch_image_url": "http://sketch_url"}}' # noqa E501 expect_req = json.loads(expect_req_str) assert expect_req == req assert 
response.status_code == HTTPStatus.OK - assert response.output.results[0]['url'] == 'url1' + assert response.output.results[0]["url"] == "url1" def test_with_not_all_parameters(self, mock_server: MockServer): response_str = json.dumps(TestSketchImageSynthesis.text_response_obj) mock_server.responses.put(response_str) - prompt = 'hello' + prompt = "hello" response = ImageSynthesis.async_call( model=ImageSynthesis.Models.wanx_sketch_to_image_v1, prompt=prompt, - sketch_image_url='http://sketch_url', + sketch_image_url="http://sketch_url", n=4, - size='1024*1024', - realisticness=9) + size="1024*1024", + realisticness=9, + ) req = mock_server.requests.get(block=True) expect_req_str = '{"model": "wanx-sketch-to-image-v1", "parameters": {"n": 4, "size": "1024*1024", "realisticness": 9}, "input": {"prompt": "hello", "sketch_image_url": "http://sketch_url"}}' # noqa E501 expect_req = json.loads(expect_req_str) assert expect_req == req assert response.status_code == HTTPStatus.OK - assert response.output.results[0]['url'] == 'url1' + assert response.output.results[0]["url"] == "url1" diff --git a/tests/test_speech_recognizer.py b/tests/test_speech_recognizer.py index 775d24f..7c081ee 100644 --- a/tests/test_speech_recognizer.py +++ b/tests/test_speech_recognizer.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
import argparse @@ -10,11 +11,14 @@ import pytest import dashscope -from dashscope.audio.asr import (Recognition, RecognitionCallback, - RecognitionResult) +from dashscope.audio.asr import ( + Recognition, + RecognitionCallback, + RecognitionResult, +) from tests.base_test import BaseTestEnvironment -logger = logging.getLogger('dashscope') +logger = logging.getLogger("dashscope") logger.setLevel(logging.DEBUG) # create console handler and set level to debug console_handler = logging.StreamHandler() @@ -22,7 +26,8 @@ # create formatter formatter = logging.Formatter( - '%(asctime)s - %(name)s - %(levelname)s - %(message)s') + "%(asctime)s - %(name)s - %(levelname)s - %(message)s", +) # add formatter to ch console_handler.setFormatter(formatter) @@ -42,16 +47,18 @@ class TestSpeechRecognition(BaseTestEnvironment): @classmethod def setup_class(cls): super().setup_class() - cls.model = 'paraformer-realtime-v1' - cls.format = 'pcm' + cls.model = "paraformer-realtime-v1" + cls.format = "pcm" cls.sample_rate = 16000 - cls.file = './tests/data/asr_example.wav' + cls.file = "./tests/data/asr_example.wav" def test_sync_call_with_file(self): - recognition = Recognition(model=self.model, - format=self.format, - sample_rate=self.sample_rate, - callback=None) + recognition = Recognition( + model=self.model, + format=self.format, + sample_rate=self.sample_rate, + callback=None, + ) result = recognition.call(self.file) assert result is not None assert result.get_sentence() is not None @@ -59,12 +66,14 @@ def test_sync_call_with_file(self): def test_async_start_with_stream(self): callback = TestCallback() - recognition = Recognition(model=self.model, - format=self.format, - sample_rate=self.sample_rate, - callback=callback) + recognition = Recognition( + model=self.model, + format=self.format, + sample_rate=self.sample_rate, + callback=callback, + ) recognition.start() - f = open(self.file, 'rb') + f = open(self.file, "rb") while True: chunk = f.read(3200) if not chunk: @@ -77,70 
+86,83 @@ def test_async_start_with_stream(self): class Callback(RecognitionCallback): def on_open(self) -> None: - print('RecognitionCallback open.') + print("RecognitionCallback open.") def on_complete(self) -> None: - print('RecognitionCallback complete.') + print("RecognitionCallback complete.") def on_error(self, result: RecognitionResult) -> None: - print('RecognitionCallback task_id: ', result.request_id) - print('RecognitionCallback error: ', result.message) + print("RecognitionCallback task_id: ", result.request_id) + print("RecognitionCallback error: ", result.message) def on_close(self) -> None: - print('RecognitionCallback close.') + print("RecognitionCallback close.") def on_event(self, result: RecognitionResult) -> None: sentence = result.get_sentence() - if 'text' in sentence: - print('RecognitionCallback text: ', sentence['text']) + if "text" in sentence: + print("RecognitionCallback text: ", sentence["text"]) if RecognitionResult.is_sentence_end(sentence): print( - 'RecognitionCallback sentence end, request_id:%s, usage:%s' - % (result.get_request_id(), result.get_usage(sentence))) + "RecognitionCallback sentence end, request_id:%s, usage:%s" + % (result.get_request_id(), result.get_usage(sentence)), + ) def str2bool(str): - return True if str.lower() == 'true' else False + return True if str.lower() == "true" else False def complete_url(url: str) -> str: parsed = urlparse(url) - base_url = ''.join([parsed.scheme, '://', parsed.netloc]) - dashscope.base_websocket_api_url = '/'.join( - [base_url, 'api-ws', dashscope.common.env.api_version, 'inference']) - dashscope.base_http_api_url = url = '/'.join( - [base_url, 'api', dashscope.common.env.api_version]) - print('Set base_websocket_api_url: ', dashscope.base_websocket_api_url) - print('Set base_http_api_url: ', dashscope.base_http_api_url) + base_url = "".join([parsed.scheme, "://", parsed.netloc]) + dashscope.base_websocket_api_url = "/".join( + [base_url, "api-ws", 
dashscope.common.env.api_version, "inference"], + ) + dashscope.base_http_api_url = url = "/".join( + [base_url, "api", dashscope.common.env.api_version], + ) + print("Set base_websocket_api_url: ", dashscope.base_websocket_api_url) + print("Set base_http_api_url: ", dashscope.base_http_api_url) @pytest.mark.skip def test_by_user(): parser = argparse.ArgumentParser() - parser.add_argument('--model', type=str, default='paraformer-realtime-v1') - parser.add_argument('--format', type=str, default='pcm') - parser.add_argument('--sample_rate', type=int, default=16000) - parser.add_argument('--file', - type=str, - default='./tests/data/asr_example.wav') - parser.add_argument('--sync', type=str2bool, default='False') - parser.add_argument('--phrase_id', type=str, default=None) - parser.add_argument('--disfluency_removal_enabled', - type=str2bool, - default='False') - parser.add_argument('--diarization_enabled', - type=str2bool, - default='False') - parser.add_argument('--speaker_count', type=int, default=None) - parser.add_argument('--timestamp_alignment_enabled', - type=str2bool, - default='False') - parser.add_argument('--special_word_filter', type=str, default=None) - parser.add_argument('--audio_event_detection_enabled', - type=str2bool, - default='False') - parser.add_argument('--api_key', type=str) - parser.add_argument('--base_url', type=str) + parser.add_argument("--model", type=str, default="paraformer-realtime-v1") + parser.add_argument("--format", type=str, default="pcm") + parser.add_argument("--sample_rate", type=int, default=16000) + parser.add_argument( + "--file", + type=str, + default="./tests/data/asr_example.wav", + ) + parser.add_argument("--sync", type=str2bool, default="False") + parser.add_argument("--phrase_id", type=str, default=None) + parser.add_argument( + "--disfluency_removal_enabled", + type=str2bool, + default="False", + ) + parser.add_argument( + "--diarization_enabled", + type=str2bool, + default="False", + ) + 
parser.add_argument("--speaker_count", type=int, default=None) + parser.add_argument( + "--timestamp_alignment_enabled", + type=str2bool, + default="False", + ) + parser.add_argument("--special_word_filter", type=str, default=None) + parser.add_argument( + "--audio_event_detection_enabled", + type=str2bool, + default="False", + ) + parser.add_argument("--api_key", type=str) + parser.add_argument("--base_url", type=str) args = parser.parse_args() if args.api_key is not None: @@ -162,7 +184,8 @@ def test_by_user(): timestamp_alignment_enabled=args.timestamp_alignment_enabled, special_word_filter=args.special_word_filter, audio_event_detection_enabled=args.audio_event_detection_enabled, - callback=callback) + callback=callback, + ) phrase_id = args.phrase_id @@ -172,17 +195,19 @@ def test_by_user(): sentences: List[Any] = result.get_sentence() if sentences and len(sentences) > 0: for sentence in sentences: - print('Recognizing: %s, usage: %s' % - (sentence, result.get_usage(sentence))) + print( + "Recognizing: %s, usage: %s" + % (sentence, result.get_usage(sentence)), + ) else: - print('Warn: get an empty recognition result: ', result) + print("Warn: get an empty recognition result: ", result) else: - print('Error: ', result.message) + print("Error: ", result.message) else: recognition.start(phrase_id=phrase_id) try: - f = open(args.file, 'rb') + f = open(args.file, "rb") while True: chunk = f.read(3200) if not chunk: @@ -192,10 +217,10 @@ def test_by_user(): time.sleep(0.1) f.close() except Exception as e: - print('Open file or send audio failed:', e) + print("Open file or send audio failed:", e) recognition.stop() -if __name__ == '__main__': +if __name__ == "__main__": test_by_user() diff --git a/tests/test_speech_synthesis.py b/tests/test_speech_synthesis.py index 7dd3b63..c25edbc 100644 --- a/tests/test_speech_synthesis.py +++ b/tests/test_speech_synthesis.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
import argparse @@ -10,11 +11,14 @@ import dashscope from dashscope.api_entities.dashscope_response import SpeechSynthesisResponse -from dashscope.audio.tts import (ResultCallback, SpeechSynthesisResult, - SpeechSynthesizer) +from dashscope.audio.tts import ( + ResultCallback, + SpeechSynthesisResult, + SpeechSynthesizer, +) from tests.base_test import BaseTestEnvironment -logger = logging.getLogger('dashscope') +logger = logging.getLogger("dashscope") logger.setLevel(logging.DEBUG) # create console handler and set level to debug console_handler = logging.StreamHandler() @@ -22,7 +26,8 @@ # create formatter formatter = logging.Formatter( - '%(asctime)s - %(name)s - %(levelname)s - %(message)s') + "%(asctime)s - %(name)s - %(levelname)s - %(message)s", +) # add formatter to ch console_handler.setFormatter(formatter) @@ -44,27 +49,37 @@ def on_event(self, result: SpeechSynthesisResult): assert audio_frame is not None or timestamp is not None if audio_frame is not None: - assert (sys.getsizeof(audio_frame) > 0) + assert sys.getsizeof(audio_frame) > 0 if timestamp is not None: - assert 'begin_time' in timestamp - assert 'end_time' in timestamp - assert len(timestamp['words']) > 0 + assert "begin_time" in timestamp + assert "end_time" in timestamp + assert len(timestamp["words"]) > 0 class TestSynthesis(BaseTestEnvironment): @classmethod def setup_class(cls): super().setup_class() - cls.model = 'sambert-zhichu-v1' - cls.text = '今天天气真不错,我想去操场踢足球。' + cls.model = "sambert-zhichu-v1" + cls.text = "今天天气真不错,我想去操场踢足球。" def check_result(self, result): assert result.get_response().status_code == HTTPStatus.OK - assert result.get_response().code is None or len( - result.get_response().code) == 0 - assert result.get_response().message is None or len( - result.get_response().message) == 0 + assert ( + result.get_response().code is None + or len( + result.get_response().code, + ) + == 0 + ) + assert ( + result.get_response().message is None + or len( + result.get_response().message, 
+ ) + == 0 + ) assert sys.getsizeof(result.get_audio_data()) > 0 def test_sync_call_with_multi_formats(self): @@ -74,103 +89,124 @@ def test_sync_call_with_multi_formats(self): model=self.model, text=self.text, callback=test_callback, - format=SpeechSynthesizer.AudioFormat.format_mp3) + format=SpeechSynthesizer.AudioFormat.format_mp3, + ) self.check_result(result) result = SpeechSynthesizer.call( model=self.model, text=self.text, callback=test_callback, - format=SpeechSynthesizer.AudioFormat.format_pcm) + format=SpeechSynthesizer.AudioFormat.format_pcm, + ) self.check_result(result) result = SpeechSynthesizer.call( model=self.model, text=self.text, callback=test_callback, - format=SpeechSynthesizer.AudioFormat.format_wav) + format=SpeechSynthesizer.AudioFormat.format_wav, + ) self.check_result(result) def test_sync_call_with_sample_rate(self): test_callback = TestCallback() - result = SpeechSynthesizer.call(model=self.model, - text=self.text, - callback=test_callback, - sample_rate=16000) + result = SpeechSynthesizer.call( + model=self.model, + text=self.text, + callback=test_callback, + sample_rate=16000, + ) self.check_result(result) - result = SpeechSynthesizer.call(model=self.model, - text=self.text, - callback=test_callback, - sample_rate=24000) + result = SpeechSynthesizer.call( + model=self.model, + text=self.text, + callback=test_callback, + sample_rate=24000, + ) self.check_result(result) def test_sync_call_with_volume(self): test_callback = TestCallback() - result = SpeechSynthesizer.call(model=self.model, - text=self.text, - callback=test_callback, - volume=1) + result = SpeechSynthesizer.call( + model=self.model, + text=self.text, + callback=test_callback, + volume=1, + ) self.check_result(result) - result = SpeechSynthesizer.call(model=self.model, - text=self.text, - callback=test_callback, - volume=100) + result = SpeechSynthesizer.call( + model=self.model, + text=self.text, + callback=test_callback, + volume=100, + ) self.check_result(result) def 
test_sync_call_with_rate(self): test_callback = TestCallback() - result = SpeechSynthesizer.call(model=self.model, - text=self.text, - callback=test_callback, - rate=-500) + result = SpeechSynthesizer.call( + model=self.model, + text=self.text, + callback=test_callback, + rate=-500, + ) self.check_result(result) - result = SpeechSynthesizer.call(model=self.model, - text=self.text, - callback=test_callback, - rate=500) + result = SpeechSynthesizer.call( + model=self.model, + text=self.text, + callback=test_callback, + rate=500, + ) self.check_result(result) - result = SpeechSynthesizer.call(model=self.model, - text=self.text, - callback=test_callback, - pitch=-500) + result = SpeechSynthesizer.call( + model=self.model, + text=self.text, + callback=test_callback, + pitch=-500, + ) self.check_result(result) - result = SpeechSynthesizer.call(model=self.model, - text=self.text, - callback=test_callback, - pitch=500) + result = SpeechSynthesizer.call( + model=self.model, + text=self.text, + callback=test_callback, + pitch=500, + ) self.check_result(result) def test_sync_call_with_timestamp(self): test_callback = TestCallback() - result = SpeechSynthesizer.call(model=self.model, - text=self.text, - callback=test_callback, - word_timestamp_enabled=True, - phoneme_timestamp_enabled=True) + result = SpeechSynthesizer.call( + model=self.model, + text=self.text, + callback=test_callback, + word_timestamp_enabled=True, + phoneme_timestamp_enabled=True, + ) self.check_result(result) class Callback(ResultCallback): def on_open(self): - print('Synthesis is opened.') + print("Synthesis is opened.") def on_complete(self): - print('Synthesis is completed.') + print("Synthesis is completed.") def on_error(self, response: SpeechSynthesisResponse): - print('Synthesis failed, response is %s' % (str(response))) + print("Synthesis failed, response is %s" % (str(response))) def on_close(self): - print('Synthesis is closed.') + print("Synthesis is closed.") def on_event(self, result: 
SpeechSynthesisResult): global first_data_flag @@ -185,54 +221,72 @@ def on_event(self, result: SpeechSynthesisResult): first_data_time = cur_time - call_time first_data_flag = False - print('get binary data: ', sys.getsizeof(audio_frame)) + print("get binary data: ", sys.getsizeof(audio_frame)) if timestamp is not None: - if 'begin_time' in timestamp and 'end_time' in timestamp: - print(' time: %d - %d' % - (timestamp['begin_time'], timestamp['end_time'])) - words_list = timestamp['words'] - print(' words: ') + if "begin_time" in timestamp and "end_time" in timestamp: + print( + " time: %d - %d" + % (timestamp["begin_time"], timestamp["end_time"]), + ) + words_list = timestamp["words"] + print(" words: ") for word in words_list: - print(' %d - %d : %s' % - (word['begin_time'], word['end_time'], word['text'])) - if 'phonemes' in word: - for phoneme in word['phonemes']: - print(' %d - %d : text: %s tone: %s' % - (phoneme['begin_time'], phoneme['end_time'], - phoneme['text'], phoneme['tone'])) + print( + " %d - %d : %s" + % (word["begin_time"], word["end_time"], word["text"]), + ) + if "phonemes" in word: + for phoneme in word["phonemes"]: + print( + " %d - %d : text: %s tone: %s" + % ( + phoneme["begin_time"], + phoneme["end_time"], + phoneme["text"], + phoneme["tone"], + ), + ) def str2bool(str): - return True if str.lower() == 'true' else False + return True if str.lower() == "true" else False @pytest.mark.skip def test_by_user(): parser = argparse.ArgumentParser() - parser.add_argument('--model', type=str, default='sambert-zhichu-v1') - parser.add_argument('--text', type=str, default='今天天气真不错,我想去操场踢足球。') - parser.add_argument('--callback', - type=str2bool, - default='False', - help='run with callback or not.') - parser.add_argument('--format', - type=str, - default=SpeechSynthesizer.AudioFormat.format_wav) - parser.add_argument('--sample_rate', type=int, default=48000) - parser.add_argument('--volume', type=int, default=50) - parser.add_argument('--rate', 
type=float, default=1.0) - parser.add_argument('--pitch', type=float, default=1.0) - parser.add_argument('--word_timestamp', - type=str2bool, - default='False', - help='run with word_timestamp or not.') - parser.add_argument('--phoneme_timestamp', - type=str2bool, - default='False', - help='run with phoneme_timestamp or not.') - parser.add_argument('--api_key', type=str) - parser.add_argument('--base_url', type=str) + parser.add_argument("--model", type=str, default="sambert-zhichu-v1") + parser.add_argument("--text", type=str, default="今天天气真不错,我想去操场踢足球。") + parser.add_argument( + "--callback", + type=str2bool, + default="False", + help="run with callback or not.", + ) + parser.add_argument( + "--format", + type=str, + default=SpeechSynthesizer.AudioFormat.format_wav, + ) + parser.add_argument("--sample_rate", type=int, default=48000) + parser.add_argument("--volume", type=int, default=50) + parser.add_argument("--rate", type=float, default=1.0) + parser.add_argument("--pitch", type=float, default=1.0) + parser.add_argument( + "--word_timestamp", + type=str2bool, + default="False", + help="run with word_timestamp or not.", + ) + parser.add_argument( + "--phoneme_timestamp", + type=str2bool, + default="False", + help="run with phoneme_timestamp or not.", + ) + parser.add_argument("--api_key", type=str) + parser.add_argument("--base_url", type=str) args = parser.parse_args() if args.api_key is not None: @@ -258,18 +312,23 @@ def test_by_user(): rate=args.rate, pitch=args.pitch, word_timestamp_enabled=args.word_timestamp, - phoneme_timestamp_enabled=args.phoneme_timestamp) - print('Speech synthesis finish:') + phoneme_timestamp_enabled=args.phoneme_timestamp, + ) + print("Speech synthesis finish:") if result.get_audio_data() is not None: - print(' get audio data: %dbytes' % - (sys.getsizeof(result.get_audio_data()))) - print(' get sentences size: %d' % (len(result.get_timestamps()))) - print(' get response: %s' % (result.get_response())) + print( + " get audio data: 
%dbytes" + % (sys.getsizeof(result.get_audio_data())), + ) + print(" get sentences size: %d" % (len(result.get_timestamps()))) + print(" get response: %s" % (result.get_response())) if first_data_time > 0: - print('The cost time of first audio data: %6dms' % - (first_data_time * 1000)) + print( + "The cost time of first audio data: %6dms" + % (first_data_time * 1000), + ) -if __name__ == '__main__': +if __name__ == "__main__": test_by_user() diff --git a/tests/test_speech_synthesis_v2.py b/tests/test_speech_synthesis_v2.py index 060a8fd..b758673 100644 --- a/tests/test_speech_synthesis_v2.py +++ b/tests/test_speech_synthesis_v2.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. import pytest @@ -8,64 +9,67 @@ class TestCallback(ResultCallback): def on_open(self): - print('websocket is open.') + print("websocket is open.") def on_complete(self): - print('speech synthesis task complete successfully.') + print("speech synthesis task complete successfully.") def on_error(self, message: str): - print(f'speech synthesis task failed, {message}') + print(f"speech synthesis task failed, {message}") def on_close(self): - print('websocket is closed.') + print("websocket is closed.") self.file.close() def on_event(self, message): - print(f'recv speech synthsis message {message}') + print(f"recv speech synthsis message {message}") def on_data(self, data: bytes) -> None: # save audio to file - print('recv speech audio {}'.format(len(data))) + print("recv speech audio {}".format(len(data))) class TestSynthesis(BaseTestEnvironment): @classmethod def setup_class(cls): super().setup_class() - cls.model = 'pre-cosyvoice-test' - cls.voice = 'longxiaochun' + cls.model = "pre-cosyvoice-test" + cls.voice = "longxiaochun" cls.text_array = [ - '流式文本语音合成SDK,', - '可以将输入的文本', - '合成为语音二进制数据,', - '相比于非流式语音合成,', - '流式合成的优势在于实时性', - '更强。用户在输入文本的同时', - '可以听到接近同步的语音输出,', - '极大地提升了交互体验,', - '减少了用户等待时间。', - '适用于调用大规模', - '语言模型(LLM),以', - '流式输入文本的方式', - 
'进行语音合成的场景。', + "流式文本语音合成SDK,", + "可以将输入的文本", + "合成为语音二进制数据,", + "相比于非流式语音合成,", + "流式合成的优势在于实时性", + "更强。用户在输入文本的同时", + "可以听到接近同步的语音输出,", + "极大地提升了交互体验,", + "减少了用户等待时间。", + "适用于调用大规模", + "语言模型(LLM),以", + "流式输入文本的方式", + "进行语音合成的场景。", ] @pytest.mark.skip def test_sync_call_with_multi_formats(self): - - synthesizer = SpeechSynthesizer(model=self.model, - voice=self.voice, - url=self.url) + synthesizer = SpeechSynthesizer( + model=self.model, + voice=self.voice, + url=self.url, + ) audio = synthesizer.call(self.text_array[0]) - print('recv audio length {}'.format(len(audio))) + print("recv audio length {}".format(len(audio))) @pytest.mark.skip def test_sync_streaming_call_with_multi_formats(self): test_callback = TestCallback() - synthesizer = SpeechSynthesizer(model=self.model, - voice=self.voice, - callback=test_callback) + synthesizer = SpeechSynthesizer( + model=self.model, + voice=self.voice, + callback=test_callback, + ) for text in self.text_array: synthesizer.streaming_call(text) synthesizer.streaming_complete() @@ -74,9 +78,11 @@ def test_sync_streaming_call_with_multi_formats(self): def test_sync_streaming_call_cancel_with_multi_formats(self): test_callback = TestCallback() - synthesizer = SpeechSynthesizer(model=self.model, - voice=self.voice, - callback=test_callback) + synthesizer = SpeechSynthesizer( + model=self.model, + voice=self.voice, + callback=test_callback, + ) for text in self.text_array: synthesizer.streaming_call(text) synthesizer.streaming_cancel() diff --git a/tests/test_speech_transcription.py b/tests/test_speech_transcription.py index 8ff83a1..aa9f770 100644 --- a/tests/test_speech_transcription.py +++ b/tests/test_speech_transcription.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
import argparse @@ -13,10 +14,10 @@ from dashscope.common.constants import TaskStatus from tests.base_test import BaseTestEnvironment -HTTPS_16K_CH1_WAV = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example.wav' # noqa: * -HTTPS_16K_CH2_WAV = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_2ch.wav' # noqa: * +HTTPS_16K_CH1_WAV = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example.wav" # noqa: * +HTTPS_16K_CH2_WAV = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_2ch.wav" # noqa: * -logger = logging.getLogger('dashscope') +logger = logging.getLogger("dashscope") logger.setLevel(logging.DEBUG) # create console handler and set level to debug console_handler = logging.StreamHandler() @@ -24,7 +25,8 @@ # create formatter formatter = logging.Formatter( - '%(asctime)s - %(name)s - %(levelname)s - %(message)s') + "%(asctime)s - %(name)s - %(levelname)s - %(message)s", +) # add formatter to ch console_handler.setFormatter(formatter) @@ -39,8 +41,10 @@ def setup_class(cls): cls.model = Transcription.Models.paraformer_8k_v1 def test_async_call_and_wait(self): - task = Transcription.async_call(model=self.model, - file_urls=[HTTPS_16K_CH1_WAV]) + task = Transcription.async_call( + model=self.model, + file_urls=[HTTPS_16K_CH1_WAV], + ) # wait results with task_id. results = task @@ -58,8 +62,10 @@ def test_async_call_and_wait(self): assert results.output.task_status == TaskStatus.SUCCEEDED def test_async_call_and_fetch(self): - task = Transcription.async_call(model=self.model, - file_urls=[HTTPS_16K_CH1_WAV]) + task = Transcription.async_call( + model=self.model, + file_urls=[HTTPS_16K_CH1_WAV], + ) # poll results with task_id. 
results = task @@ -67,9 +73,11 @@ def test_async_call_and_fetch(self): while True: results = Transcription.fetch(task.output.task_id) if results.status_code == HTTPStatus.OK: - if (results.output is not None - and results.output.task_status - in [TaskStatus.PENDING, TaskStatus.RUNNING]): + if ( + results.output is not None + and results.output.task_status + in [TaskStatus.PENDING, TaskStatus.RUNNING] + ): time.sleep(2) continue @@ -82,8 +90,11 @@ def test_async_call_and_fetch(self): while True: results = Transcription.fetch(task) if results.status_code == HTTPStatus.OK: - if (results.output is not None and results.output.task_status - in [TaskStatus.PENDING, TaskStatus.RUNNING]): + if ( + results.output is not None + and results.output.task_status + in [TaskStatus.PENDING, TaskStatus.RUNNING] + ): time.sleep(2) continue @@ -93,61 +104,77 @@ def test_async_call_and_fetch(self): assert results.output.task_status == TaskStatus.SUCCEEDED def test_sync_call(self): - results = Transcription.call(model=self.model, - file_urls=[HTTPS_16K_CH1_WAV]) + results = Transcription.call( + model=self.model, + file_urls=[HTTPS_16K_CH1_WAV], + ) assert results.status_code == HTTPStatus.OK assert results.output is not None assert results.output.task_status == TaskStatus.SUCCEEDED def test_sync_call_with_2ch(self): - results = Transcription.call(model=self.model, - file_urls=[HTTPS_16K_CH2_WAV], - channel_id=[0, 1]) + results = Transcription.call( + model=self.model, + file_urls=[HTTPS_16K_CH2_WAV], + channel_id=[0, 1], + ) assert results.status_code == HTTPStatus.OK assert results.output is not None assert results.output.task_status == TaskStatus.SUCCEEDED def str2bool(str): - return True if str.lower() == 'true' else False + return True if str.lower() == "true" else False def complete_url(url: str) -> str: parsed = urlparse(url) - base_url = ''.join([parsed.scheme, '://', parsed.netloc]) - dashscope.base_websocket_api_url = '/'.join( - [base_url, 'api-ws', 
dashscope.common.env.api_version, 'inference']) - dashscope.base_http_api_url = url = '/'.join( - [base_url, 'api', dashscope.common.env.api_version]) - print('Set base_websocket_api_url: ', dashscope.base_websocket_api_url) - print('Set base_http_api_url: ', dashscope.base_http_api_url) + base_url = "".join([parsed.scheme, "://", parsed.netloc]) + dashscope.base_websocket_api_url = "/".join( + [base_url, "api-ws", dashscope.common.env.api_version, "inference"], + ) + dashscope.base_http_api_url = url = "/".join( + [base_url, "api", dashscope.common.env.api_version], + ) + print("Set base_websocket_api_url: ", dashscope.base_websocket_api_url) + print("Set base_http_api_url: ", dashscope.base_http_api_url) @pytest.mark.skip def test_by_user(): parser = argparse.ArgumentParser() - parser.add_argument('--model', - type=str, - default=Transcription.Models.paraformer_v1) - parser.add_argument('--files', type=str, default=HTTPS_16K_CH1_WAV) - parser.add_argument('--sync', type=str2bool, default='False') - parser.add_argument('--phrase_id', type=str, default=None) - parser.add_argument('--disfluency_removal_enabled', - type=str2bool, - default='False') - parser.add_argument('--diarization_enabled', - type=str2bool, - default='False') - parser.add_argument('--speaker_count', type=int, default=None) - parser.add_argument('--timestamp_alignment_enabled', - type=str2bool, - default='False') - parser.add_argument('--special_word_filter', type=str, default=None) - parser.add_argument('--audio_event_detection_enabled', - type=str2bool, - default='False') - parser.add_argument('--api_key', type=str) - parser.add_argument('--base_url', type=str) + parser.add_argument( + "--model", + type=str, + default=Transcription.Models.paraformer_v1, + ) + parser.add_argument("--files", type=str, default=HTTPS_16K_CH1_WAV) + parser.add_argument("--sync", type=str2bool, default="False") + parser.add_argument("--phrase_id", type=str, default=None) + parser.add_argument( + 
"--disfluency_removal_enabled", + type=str2bool, + default="False", + ) + parser.add_argument( + "--diarization_enabled", + type=str2bool, + default="False", + ) + parser.add_argument("--speaker_count", type=int, default=None) + parser.add_argument( + "--timestamp_alignment_enabled", + type=str2bool, + default="False", + ) + parser.add_argument("--special_word_filter", type=str, default=None) + parser.add_argument( + "--audio_event_detection_enabled", + type=str2bool, + default="False", + ) + parser.add_argument("--api_key", type=str) + parser.add_argument("--base_url", type=str) args = parser.parse_args() if args.api_key is not None: @@ -167,8 +194,9 @@ def test_by_user(): speaker_count=args.speaker_count, timestamp_alignment_enabled=args.timestamp_alignment_enabled, special_word_filter=args.special_word_filter, - audio_event_detection_enabled=args.audio_event_detection_enabled) - print('sync output: ', results.output) + audio_event_detection_enabled=args.audio_event_detection_enabled, + ) + print("sync output: ", results.output) else: task = Transcription.async_call( model=args.model, @@ -179,33 +207,36 @@ def test_by_user(): speaker_count=args.speaker_count, timestamp_alignment_enabled=args.timestamp_alignment_enabled, special_word_filter=args.special_word_filter, - audio_event_detection_enabled=args.audio_event_detection_enabled) - print('async task code: ', task.status_code) - print('async task output: ', task.output) + audio_event_detection_enabled=args.audio_event_detection_enabled, + ) + print("async task code: ", task.status_code) + print("async task output: ", task.output) results = None if task.status_code == HTTPStatus.OK: while True: results = Transcription.fetch(task) if results.status_code == HTTPStatus.OK: - if (results.output is not None - and results.output.task_status - in [TaskStatus.PENDING, TaskStatus.RUNNING]): + if ( + results.output is not None + and results.output.task_status + in [TaskStatus.PENDING, TaskStatus.RUNNING] + ): time.sleep(2) 
continue break - print('async output: ', results.output) - print('async task_status of output: ', results.output.task_status) - print('async results of output: ', results.output.results) + print("async output: ", results.output) + print("async task_status of output: ", results.output.task_status) + print("async results of output: ", results.output.results) results = Transcription.wait(task) - print('async output with wait: ', results.output) + print("async output with wait: ", results.output) else: - print('async failed') + print("async failed") -if __name__ == '__main__': +if __name__ == "__main__": test_by_user() diff --git a/tests/test_text_embedding.py b/tests/test_text_embedding.py index 86cebe6..64699fd 100644 --- a/tests/test_text_embedding.py +++ b/tests/test_text_embedding.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. from http import HTTPStatus @@ -8,20 +9,26 @@ class TestTextEmbeddingRequest(MockRequestBase): def test_call_with_string(self, http_server): - resp = TextEmbedding.call(model=TextEmbedding.Models.text_embedding_v3, - input='hello') + resp = TextEmbedding.call( + model=TextEmbedding.Models.text_embedding_v3, + input="hello", + ) assert resp.status_code == HTTPStatus.OK - assert len(resp.output['embeddings']) == 1 + assert len(resp.output["embeddings"]) == 1 def test_call_with_list_str(self, http_server): - resp = TextEmbedding.call(model=TextEmbedding.Models.text_embedding_v3, - input=['hello', 'world']) + resp = TextEmbedding.call( + model=TextEmbedding.Models.text_embedding_v3, + input=["hello", "world"], + ) assert resp.status_code == HTTPStatus.OK - assert len(resp.output['embeddings']) == 1 + assert len(resp.output["embeddings"]) == 1 def test_call_with_opened_file(self, http_server): - with open('tests/data/multi_line.txt') as f: + with open("tests/data/multi_line.txt") as f: response = TextEmbedding.call( - model=TextEmbedding.Models.text_embedding_v3, input=f) + 
model=TextEmbedding.Models.text_embedding_v3, + input=f, + ) assert response.status_code == HTTPStatus.OK - assert len(response.output['embeddings']) == 1 + assert len(response.output["embeddings"]) == 1 diff --git a/tests/test_threads.py b/tests/test_threads.py index b88298c..744ac92 100644 --- a/tests/test_threads.py +++ b/tests/test_threads.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. import json @@ -11,12 +12,12 @@ class TestThreads(MockServerBase): def test_create_with_no_messages(self, mock_server: MockServer): thread_id = str(uuid.uuid4()) - metadata = {'key': 'value'} + metadata = {"key": "value"} response_obj = { - 'id': thread_id, - 'object': 'thread', - 'created_at': 1699012949, - 'metadata': metadata + "id": thread_id, + "object": "thread", + "created_at": 1699012949, + "metadata": metadata, } response_body = json.dumps(response_obj) mock_server.responses.put(response_body) @@ -24,42 +25,45 @@ def test_create_with_no_messages(self, mock_server: MockServer): req = mock_server.requests.get(block=True) assert response.id == thread_id assert response.metadata == metadata - req['metadata'] == metadata + req["metadata"] == metadata def test_create_with_messages(self, mock_server: MockServer): thread_id = str(uuid.uuid4()) - metadata = {'key': 'value'} + metadata = {"key": "value"} response_obj = { - 'id': thread_id, - 'object': 'thread', - 'created_at': 1699012949, - 'metadata': metadata + "id": thread_id, + "object": "thread", + "created_at": 1699012949, + "metadata": metadata, } response_body = json.dumps(response_obj) mock_server.responses.put(response_body) - messages = [{ - 'role': 'user', - 'content': 'How does AI work? Explain it in simple terms.', - 'file_ids': ['123'] - }, { - 'role': 'user', - 'content': '画幅画' - }] + messages = [ + { + "role": "user", + "content": "How does AI work? 
Explain it in simple terms.", + "file_ids": ["123"], + }, + { + "role": "user", + "content": "画幅画", + }, + ] thread = Threads.create(messages=messages) assert thread.id == thread_id assert thread.metadata == metadata req = mock_server.requests.get(block=True) - req['messages'] == messages + req["messages"] == messages def test_retrieve(self, mock_server: MockServer): thread_id = str(uuid.uuid4()) - metadata = {'key': 'value'} + metadata = {"key": "value"} response_obj = { - 'id': thread_id, - 'object': 'thread', - 'created_at': 1699012949, - 'metadata': metadata + "id": thread_id, + "object": "thread", + "created_at": 1699012949, + "metadata": metadata, } response_body = json.dumps(response_obj) mock_server.responses.put(response_body) @@ -72,29 +76,29 @@ def test_retrieve(self, mock_server: MockServer): def test_update(self, mock_server: MockServer): thread_id = str(uuid.uuid4()) - metadata = {'key': 'value'} + metadata = {"key": "value"} response_obj = { - 'id': thread_id, - 'object': 'thread', - 'created_at': 1699012949, - 'metadata': metadata + "id": thread_id, + "object": "thread", + "created_at": 1699012949, + "metadata": metadata, } response_body = json.dumps(response_obj) mock_server.responses.put(response_body) response = Threads.update(thread_id, metadata=metadata) # get thread id we send. 
req = mock_server.requests.get(block=True) - assert req['metadata'] == metadata + assert req["metadata"] == metadata assert response.id == thread_id assert response.metadata == metadata def test_delete(self, mock_server: MockServer): thread_id = str(uuid.uuid4()) response_obj = { - 'id': thread_id, - 'object': 'thread', - 'created_at': 1699012949, - 'deleted': True + "id": thread_id, + "object": "thread", + "created_at": 1699012949, + "deleted": True, } mock_server.responses.put(json.dumps(response_obj)) response = Threads.delete(thread_id) diff --git a/tests/test_tokenization.py b/tests/test_tokenization.py index 7b85726..3fa7683 100644 --- a/tests/test_tokenization.py +++ b/tests/test_tokenization.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. import json @@ -10,30 +11,29 @@ class TestTokenization(MockServerBase): text_response_obj = { - 'output': { - 'token_ids': [115798, 198], - 'tokens': ['<|im_start|>', '\n'], - 'prompt': - '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n如何做土豆炖猪脚?<|im_end|>\n<|im_start|>assistant\n' # noqa E501 + "output": { + "token_ids": [115798, 198], + "tokens": ["<|im_start|>", "\n"], + "prompt": "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n如何做土豆炖猪脚?<|im_end|>\n<|im_start|>assistant\n", # noqa E501 }, - 'usage': { - 'input_tokens': 28 + "usage": { + "input_tokens": 28, }, - 'request_id': 'c25e14cf-f986-900b-853c-4644a3196f39' + "request_id": "c25e14cf-f986-900b-853c-4644a3196f39", } def test_default_no_resources_request(self, mock_server: MockServer): response_str = json.dumps(TestTokenization.text_response_obj) mock_server.responses.put(response_str) - prompt = 'hello' + prompt = "hello" model = Tokenization.Models.qwen_turbo response = Tokenization.call(model=model, prompt=prompt) req = mock_server.requests.get(block=True) - assert req['model'] == model - assert req['parameters'] == {} - assert req['input'] == {'prompt': 
prompt} + assert req["model"] == model + assert req["parameters"] == {} + assert req["input"] == {"prompt": prompt} assert response.status_code == HTTPStatus.OK - assert len(response.output['token_ids']) == 2 - assert len(response.output['tokens']) == 2 - assert response.usage['input_tokens'] == 28 + assert len(response.output["token_ids"]) == 2 + assert len(response.output["tokens"]) == 2 + assert response.usage["input_tokens"] == 28 diff --git a/tests/test_tokenizer.py b/tests/test_tokenizer.py index ca69382..903e1f6 100644 --- a/tests/test_tokenizer.py +++ b/tests/test_tokenizer.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. import os @@ -5,28 +6,44 @@ from dashscope.tokenizers.tokenizer import get_tokenizer -class TestTokenization(): +class TestTokenization: @classmethod def setup_class(cls): # install tiktoken - os.system('pip install tiktoken') + os.system("pip install tiktoken") def test_encode_decode(self): - tokenizer = get_tokenizer('qwen-7b-chat') + tokenizer = get_tokenizer("qwen-7b-chat") - input_str = '这个是千问tokenizer' + input_str = "这个是千问tokenizer" - tokens = tokenizer.encode('这个是千问tokenizer') + tokens = tokenizer.encode("这个是千问tokenizer") decoded_str = tokenizer.decode(tokens) - assert (input_str == decoded_str) - - assert ([151643] == tokenizer.encode('<|endoftext|>', - allowed_special={'<|endoftext|>' - })) - assert ([151643] == tokenizer.encode('<|endoftext|>', - allowed_special='all')) - assert ([27, 91, 8691, 723, 427, 91, - 29] == tokenizer.encode('<|endoftext|>', - allowed_special=set())) - assert ([151643] == tokenizer.encode('<|endoftext|>', - disallowed_special=set())) + assert input_str == decoded_str + + assert [151643] == tokenizer.encode( + "<|endoftext|>", + allowed_special={ + "<|endoftext|>", + }, + ) + assert [151643] == tokenizer.encode( + "<|endoftext|>", + allowed_special="all", + ) + assert [ + 27, + 91, + 8691, + 723, + 427, + 91, + 29, + ] == tokenizer.encode( + "<|endoftext|>", + 
allowed_special=set(), + ) + assert [151643] == tokenizer.encode( + "<|endoftext|>", + disallowed_special=set(), + ) diff --git a/tests/test_translation_recognizer.py b/tests/test_translation_recognizer.py index 27e959e..6031779 100644 --- a/tests/test_translation_recognizer.py +++ b/tests/test_translation_recognizer.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. import os @@ -7,7 +8,10 @@ from dashscope.audio.asr import TranslationRecognizerCallback from dashscope.audio.asr.translation_recognizer import ( - TranscriptionResult, TranslationRecognizerRealtime, TranslationResult) + TranscriptionResult, + TranslationRecognizerRealtime, + TranslationResult, +) from tests.base_test import BaseTestEnvironment @@ -16,14 +20,14 @@ def __init__(self, tag, file_path) -> None: super().__init__() self.tag = tag self.file_path = file_path - self.text = '' - self.translate_text = '' + self.text = "" + self.translate_text = "" def on_open(self) -> None: - print(f'[{self.tag}] TranslationRecognizerCallback open.') + print(f"[{self.tag}] TranslationRecognizerCallback open.") def on_close(self) -> None: - print(f'[{self.tag}] TranslationRecognizerCallback close.') + print(f"[{self.tag}] TranslationRecognizerCallback close.") def on_event( self, @@ -33,7 +37,7 @@ def on_event( usage, ) -> None: if translation_result is not None: - translation = translation_result.get_translation('en') + translation = translation_result.get_translation("en") # print(f'[{self.tag}]RecognitionCallback text: ', sentence['text']) partial recognition result if translation.is_sentence_end: self.translate_text = self.translate_text + translation.text @@ -42,26 +46,26 @@ def on_event( self.text = self.text + transcription_result.text def on_error(self, message) -> None: - print('error: {}'.format(message)) + print("error: {}".format(message)) def on_complete(self) -> None: - print(f'[{self.tag}] Transcript ==> ', self.text) - print(f'[{self.tag}] Translate ==> ', 
self.translate_text) - print(f'[{self.tag}] Translation completed') # translation complete + print(f"[{self.tag}] Transcript ==> ", self.text) + print(f"[{self.tag}] Translate ==> ", self.translate_text) + print(f"[{self.tag}] Translation completed") # translation complete class TestSynthesis(BaseTestEnvironment): @classmethod def setup_class(cls): super().setup_class() - cls.model = 'gummy-realtime-v1' - cls.format = 'pcm' + cls.model = "gummy-realtime-v1" + cls.format = "pcm" cls.sample_rate = 16000 - cls.file = './tests/data/asr_example.wav' + cls.file = "./tests/data/asr_example.wav" @pytest.mark.skip def test_translate_from_file(self): - callback = Callback(f'process {os.getpid()}', self.file) + callback = Callback(f"process {os.getpid()}", self.file) # Call translation service by async mode, you can customize the translation parameters, like model, format, # sample_rate For more information, please refer to https://help.aliyun.com/document_detail/2712536.html @@ -71,7 +75,7 @@ def test_translate_from_file(self): sample_rate=self.sample_rate, transcription_enabled=True, translation_enabled=True, - translation_target_languages=['en'], + translation_target_languages=["en"], callback=callback, ) @@ -80,7 +84,7 @@ def test_translate_from_file(self): try: audio_data: bytes = None - f = open(self.file, 'rb') + f = open(self.file, "rb") if os.path.getsize(self.file): while True: audio_data = f.read(3200) @@ -91,16 +95,17 @@ def test_translate_from_file(self): time.sleep(0.01) else: raise Exception( - 'The supplied file was empty (zero bytes long)') + "The supplied file was empty (zero bytes long)", + ) f.close() except Exception as e: raise e translator.stop() print( - '[Metric] requestId: {}, first package delay ms: {}, last package delay ms: {}' - .format( + "[Metric] requestId: {}, first package delay ms: {}, last package delay ms: {}".format( translator.get_last_request_id(), translator.get_first_package_delay(), translator.get_last_package_delay(), - )) + ), + ) 
diff --git a/tests/test_understanding.py b/tests/test_understanding.py index ee58c9a..a0c340a 100644 --- a/tests/test_understanding.py +++ b/tests/test_understanding.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. import json @@ -12,39 +13,41 @@ class TestUnderstandingRequest(MockServerBase): text_response_obj = { - 'status_code': 200, - 'request_id': 'addddde9-1b8b-9707-979a-09b61f91302d', - 'code': '', - 'message': '', - 'output': { - 'rt': 0.06531630805693567, - 'text': '积极;' + "status_code": 200, + "request_id": "addddde9-1b8b-9707-979a-09b61f91302d", + "code": "", + "message": "", + "output": { + "rt": 0.06531630805693567, + "text": "积极;", + }, + "usage": { + "total_tokens": 22, + "output_tokens": 2, + "input_tokens": 20, }, - 'usage': { - 'total_tokens': 22, - 'output_tokens': 2, - 'input_tokens': 20 - } } def test_http_text_call(self, mock_server: MockServer): response_str = json.dumps(TestUnderstandingRequest.text_response_obj) mock_server.responses.put(response_str) - response = Understanding.call(model=Understanding.Models.opennlu_v1, - sentence='老师今天表扬我了', - labels='积极,消极', - task='classification') + response = Understanding.call( + model=Understanding.Models.opennlu_v1, + sentence="老师今天表扬我了", + labels="积极,消极", + task="classification", + ) req = mock_server.requests.get(block=True) - assert req['model'] == model - assert req['input']['sentence'] == '老师今天表扬我了' - assert req['input']['labels'] == '积极,消极' - assert req['input']['task'] == 'classification' - assert req['parameters'] == {} + assert req["model"] == model + assert req["input"]["sentence"] == "老师今天表扬我了" + assert req["input"]["labels"] == "积极,消极" + assert req["input"]["task"] == "classification" + assert req["parameters"] == {} assert response.status_code == HTTPStatus.OK - assert response.request_id == 'addddde9-1b8b-9707-979a-09b61f91302d' - assert response.output['rt'] == 0.06531630805693567 - assert response.output['text'] == '积极;' - assert 
response.usage['total_tokens'] == 22 - assert response.usage['output_tokens'] == 2 - assert response.usage['input_tokens'] == 20 + assert response.request_id == "addddde9-1b8b-9707-979a-09b61f91302d" + assert response.output["rt"] == 0.06531630805693567 + assert response.output["text"] == "积极;" + assert response.usage["total_tokens"] == 22 + assert response.usage["output_tokens"] == 2 + assert response.usage["input_tokens"] == 20 diff --git a/tests/test_video_synthesis.py b/tests/test_video_synthesis.py index 674f822..d845322 100644 --- a/tests/test_video_synthesis.py +++ b/tests/test_video_synthesis.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. from http import HTTPStatus @@ -15,18 +16,18 @@ def setup_class(cls): def test_create_task(self): rsp = VideoSynthesis.call( model=VideoSynthesis.Models.wanx_kf2v, - first_frame_url='https://static.dingtalk.com/media/lQLPD2ob9dfKPBvNBADNAkCwOcPjjaFVcEcHqrO8n1BLAA_576_1024.png_620x10000q90.png', - last_frame_url='https://static.dingtalk.com/media/lQLPD2jl0mg85BvNBADNAkCwNJvjWJXBVMwHqrO0OvZlAA_576_1024.png_620x10000q90.png' + first_frame_url="https://static.dingtalk.com/media/lQLPD2ob9dfKPBvNBADNAkCwOcPjjaFVcEcHqrO8n1BLAA_576_1024.png_620x10000q90.png", + last_frame_url="https://static.dingtalk.com/media/lQLPD2jl0mg85BvNBADNAkCwNJvjWJXBVMwHqrO0OvZlAA_576_1024.png_620x10000q90.png", ) assert rsp.status_code == HTTPStatus.OK - assert rsp.output['task_status'] == 'SUCCEEDED' + assert rsp.output["task_status"] == "SUCCEEDED" def test_fetch_status(self): rsp = VideoSynthesis.call( model=VideoSynthesis.Models.wanx_kf2v, - first_frame_url='https://static.dingtalk.com/media/lQLPD2ob9dfKPBvNBADNAkCwOcPjjaFVcEcHqrO8n1BLAA_576_1024.png_620x10000q90.png', - last_frame_url='https://static.dingtalk.com/media/lQLPD2jl0mg85BvNBADNAkCwNJvjWJXBVMwHqrO0OvZlAA_576_1024.png_620x10000q90.png' + 
first_frame_url="https://static.dingtalk.com/media/lQLPD2ob9dfKPBvNBADNAkCwOcPjjaFVcEcHqrO8n1BLAA_576_1024.png_620x10000q90.png", + last_frame_url="https://static.dingtalk.com/media/lQLPD2jl0mg85BvNBADNAkCwNJvjWJXBVMwHqrO0OvZlAA_576_1024.png_620x10000q90.png", ) assert rsp.status_code == HTTPStatus.OK @@ -36,23 +37,23 @@ def test_fetch_status(self): def test_wait(self): rsp = VideoSynthesis.async_call( model=VideoSynthesis.Models.wanx_kf2v, - first_frame_url='https://static.dingtalk.com/media/lQLPD2ob9dfKPBvNBADNAkCwOcPjjaFVcEcHqrO8n1BLAA_576_1024.png_620x10000q90.png', - last_frame_url='https://static.dingtalk.com/media/lQLPD2jl0mg85BvNBADNAkCwNJvjWJXBVMwHqrO0OvZlAA_576_1024.png_620x10000q90.png' + first_frame_url="https://static.dingtalk.com/media/lQLPD2ob9dfKPBvNBADNAkCwOcPjjaFVcEcHqrO8n1BLAA_576_1024.png_620x10000q90.png", + last_frame_url="https://static.dingtalk.com/media/lQLPD2jl0mg85BvNBADNAkCwNJvjWJXBVMwHqrO0OvZlAA_576_1024.png_620x10000q90.png", ) assert rsp.status_code == HTTPStatus.OK rsp = VideoSynthesis.wait(rsp) assert rsp.status_code == HTTPStatus.OK - assert rsp.output.task_id != '' # verify access by properties. + assert rsp.output.task_id != "" # verify access by properties. 
assert rsp.output.task_status == TaskStatus.SUCCEEDED - assert rsp.output.video_url != '' + assert rsp.output.video_url != "" - assert rsp.output['task_id'] != '' - assert rsp.output['task_status'] == TaskStatus.SUCCEEDED - assert rsp.output['video_url'] != '' + assert rsp.output["task_id"] != "" + assert rsp.output["task_status"] == TaskStatus.SUCCEEDED + assert rsp.output["video_url"] != "" def test_list_cancel_task(self): - rsp = VideoSynthesis.list(status='CANCELED') + rsp = VideoSynthesis.list(status="CANCELED") assert rsp.status_code == HTTPStatus.OK def test_list_all(self): diff --git a/tests/test_websocket_async_api.py b/tests/test_websocket_async_api.py index e667a57..7551154 100644 --- a/tests/test_websocket_async_api.py +++ b/tests/test_websocket_async_api.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. import pytest @@ -9,7 +10,7 @@ from tests.websocket_task_request import WebSocketRequest # set mock server url. -base_websocket_api_url = 'ws://localhost:8080/ws/aigc/v1' +base_websocket_api_url = "ws://localhost:8080/ws/aigc/v1" # text output: text binary out put: image @@ -21,34 +22,36 @@ class TestWebSocketAsyncRequest(BaseTestEnvironment): # test streaming none @pytest.mark.asyncio async def test_streaming_none_text_to_text(self, http_server): - dashscope.base_websocket_api_url = '%s/echo' % base_websocket_api_url + dashscope.base_websocket_api_url = "%s/echo" % base_websocket_api_url responses = await WebSocketRequest.aio_call( - model='qwen-turbo', + model="qwen-turbo", task=TestTasks.streaming_none_text_to_text, - prompt='hello', + prompt="hello", stream=self.stream, max_tokens=1024, ws_stream_mode=WebsocketStreamingMode.NONE, is_binary_input=False, - n=50) + n=50, + ) if self.stream: async for resp in responses: - assert resp.output['text'] == 'hello' + assert resp.output["text"] == "hello" else: - assert responses.output['text'] == 'hello' + assert responses.output["text"] == "hello" 
@pytest.mark.asyncio async def test_streaming_none_text_to_binary(self, http_server): - dashscope.base_websocket_api_url = '%s/out' % base_websocket_api_url + dashscope.base_websocket_api_url = "%s/out" % base_websocket_api_url responses = await WebSocketRequest.aio_call( - model='qwen-turbo', + model="qwen-turbo", task=TestTasks.streaming_none_text_to_binary, - prompt='hello', + prompt="hello", stream=self.stream, max_tokens=1024, ws_stream_mode=WebsocketStreamingMode.NONE, is_binary_input=False, - n=50) + n=50, + ) if self.stream: async for resp in responses: assert resp.output == bytes([0x01] * 100) @@ -57,37 +60,41 @@ async def test_streaming_none_text_to_binary(self, http_server): @pytest.mark.asyncio async def test_streaming_none_binary_to_text(self, http_server): - dashscope.base_websocket_api_url = '%s/in' % base_websocket_api_url # noqa E501 + dashscope.base_websocket_api_url = ( + "%s/in" % base_websocket_api_url + ) # noqa E501 video = bytes([0x01] * 100) responses = await WebSocketRequest.aio_call( - model='qwen-turbo', + model="qwen-turbo", task=TestTasks.streaming_none_binary_to_text, prompt=video, stream=self.stream, max_tokens=1024, ws_stream_mode=WebsocketStreamingMode.NONE, is_binary_input=True, - n=50) + n=50, + ) if self.stream: async for resp in responses: - assert resp.output['text'] == 'world' + assert resp.output["text"] == "world" else: - assert responses.output['text'] == 'world' + assert responses.output["text"] == "world" @pytest.mark.asyncio async def test_streaming_none_binary_to_binary(self, http_server): - dashscope.base_websocket_api_url = '%s/inout' % base_websocket_api_url + dashscope.base_websocket_api_url = "%s/inout" % base_websocket_api_url video = bytes([0x01] * 100) responses = await WebSocketRequest.aio_call( - model='qwen-turbo', + model="qwen-turbo", task=TestTasks.streaming_none_binary_to_binary, prompt=video, stream=self.stream, max_tokens=1024, ws_stream_mode=WebsocketStreamingMode.NONE, is_binary_input=True, - n=50) + 
n=50, + ) if self.stream: async for resp in responses: assert resp.output == bytes([0x01] * 100) @@ -97,45 +104,48 @@ async def test_streaming_none_binary_to_binary(self, http_server): # test string in @pytest.mark.asyncio async def test_streaming_in_text_to_text(self, http_server): - dashscope.base_websocket_api_url = '%s/echo' % base_websocket_api_url + dashscope.base_websocket_api_url = "%s/echo" % base_websocket_api_url def make_input(): for i in range(10): - yield 'input message %s' % i + yield "input message %s" % i responses = await WebSocketRequest.aio_call( - model='qwen-turbo', + model="qwen-turbo", task=TestTasks.streaming_in_text_to_text, prompt=make_input(), stream=self.stream, max_tokens=1024, ws_stream_mode=WebsocketStreamingMode.IN, is_binary_input=False, - n=50) + n=50, + ) if self.stream: async for resp in responses: - assert resp.output['text'] == 'world' + assert resp.output["text"] == "world" else: - assert responses.output[ - 'text'] == 'world' # echo the input out(input 10) + assert ( + responses.output["text"] == "world" + ) # echo the input out(input 10) @pytest.mark.asyncio async def test_streaming_in_text_to_binary(self, http_server): - dashscope.base_websocket_api_url = '%s/out' % base_websocket_api_url + dashscope.base_websocket_api_url = "%s/out" % base_websocket_api_url def make_input(): for i in range(10): - yield 'input message %s' % i + yield "input message %s" % i responses = await WebSocketRequest.aio_call( - model='qwen-turbo', + model="qwen-turbo", task=TestTasks.streaming_in_text_to_binary, prompt=make_input(), stream=self.stream, max_tokens=1024, ws_stream_mode=WebsocketStreamingMode.IN, is_binary_input=False, - n=50) + n=50, + ) if self.stream: async for resp in responses: assert resp.output == bytes([0x01] * 100) @@ -144,45 +154,49 @@ def make_input(): @pytest.mark.asyncio async def test_streaming_in_binary_to_text(self, http_server): - dashscope.base_websocket_api_url = '%s/in' % base_websocket_api_url # noqa E501 + 
dashscope.base_websocket_api_url = ( + "%s/in" % base_websocket_api_url + ) # noqa E501 def make_input(): for i in range(10): yield bytes([0x01] * 100) responses = await WebSocketRequest.aio_call( - model='qwen-turbo', + model="qwen-turbo", task=TestTasks.streaming_in_binary_to_text, prompt=make_input(), stream=self.stream, max_tokens=1024, ws_stream_mode=WebsocketStreamingMode.IN, is_binary_input=True, - n=50) + n=50, + ) if self.stream: async for resp in responses: - assert resp.output['text'] == 'world' + assert resp.output["text"] == "world" else: - assert responses.output['text'] == 'world' + assert responses.output["text"] == "world" @pytest.mark.asyncio async def test_streaming_in_binary_to_binary(self, http_server): - dashscope.base_websocket_api_url = '%s/inout' % base_websocket_api_url + dashscope.base_websocket_api_url = "%s/inout" % base_websocket_api_url def make_input(): for i in range(10): yield bytes([0x01] * 100) responses = await WebSocketRequest.aio_call( - model='qwen-turbo', + model="qwen-turbo", task=TestTasks.streaming_in_binary_to_binary, prompt=make_input(), stream=self.stream, max_tokens=1024, ws_stream_mode=WebsocketStreamingMode.IN, is_binary_input=True, - n=50) + n=50, + ) if self.stream: async for resp in responses: assert resp.output == bytes([0x01] * 100) @@ -192,120 +206,130 @@ def make_input(): # streaming out @pytest.mark.asyncio async def test_streaming_out_text_to_text(self, http_server): - dashscope.base_websocket_api_url = '%s/echo' % base_websocket_api_url + dashscope.base_websocket_api_url = "%s/echo" % base_websocket_api_url responses = await WebSocketRequest.aio_call( - model='qwen-turbo', + model="qwen-turbo", task=TestTasks.streaming_out_text_to_text, - prompt='hello', + prompt="hello", stream=self.stream, max_tokens=1024, ws_stream_mode=WebsocketStreamingMode.OUT, is_binary_input=False, - n=50) + n=50, + ) if self.stream: async for resp in responses: - assert resp.output['text'] == 'world' + assert resp.output["text"] 
== "world" else: - assert responses.output['text'] == 'world' + assert responses.output["text"] == "world" @pytest.mark.asyncio async def test_streaming_out_text_to_binary(self, http_server): - dashscope.base_websocket_api_url = '%s/out' % base_websocket_api_url + dashscope.base_websocket_api_url = "%s/out" % base_websocket_api_url responses = await WebSocketRequest.aio_call( - model='qwen-turbo', + model="qwen-turbo", task=TestTasks.streaming_out_text_to_binary, - prompt='hello', + prompt="hello", stream=self.stream, max_tokens=1024, ws_stream_mode=WebsocketStreamingMode.OUT, is_binary_input=False, - n=50) + n=50, + ) if self.stream: async for resp in responses: assert resp.output == bytes([0x01] * 100) else: assert len(responses.output) and responses.output == bytes( - [0x01] * 100) + [0x01] * 100, + ) @pytest.mark.asyncio async def test_streaming_out_binary_to_text(self, http_server): - dashscope.base_websocket_api_url = '%s/in' % base_websocket_api_url # noqa E501 + dashscope.base_websocket_api_url = ( + "%s/in" % base_websocket_api_url + ) # noqa E501 responses = await WebSocketRequest.aio_call( - model='qwen-turbo', + model="qwen-turbo", task=TestTasks.streaming_out_binary_to_text, prompt=bytes([0x01] * 100), max_tokens=1024, stream=self.stream, ws_stream_mode=WebsocketStreamingMode.OUT, is_binary_input=True, - n=50) + n=50, + ) if self.stream: async for resp in responses: - assert resp.output['text'] == 'world' + assert resp.output["text"] == "world" else: - assert responses.output['text'] == 'world' + assert responses.output["text"] == "world" @pytest.mark.asyncio async def test_streaming_out_binary_to_binary(self, http_server): - dashscope.base_websocket_api_url = '%s/inout' % base_websocket_api_url + dashscope.base_websocket_api_url = "%s/inout" % base_websocket_api_url responses = await WebSocketRequest.aio_call( - model='qwen-turbo', + model="qwen-turbo", task=TestTasks.streaming_out_binary_to_binary, prompt=bytes([0x01] * 100), stream=self.stream, 
max_tokens=1024, ws_stream_mode=WebsocketStreamingMode.OUT, is_binary_input=True, - n=50) + n=50, + ) if self.stream: async for resp in responses: assert resp.output == bytes([0x01] * 100) else: assert len(responses.output) == 100 and responses.output == bytes( - [0x01] * 100) + [0x01] * 100, + ) # streaming duplex @pytest.mark.asyncio async def test_streaming_duplex_text_to_text(self, http_server): - dashscope.base_websocket_api_url = '%s/echo' % base_websocket_api_url + dashscope.base_websocket_api_url = "%s/echo" % base_websocket_api_url def make_input(): for i in range(10): - yield 'input message %s' % i + yield "input message %s" % i responses = await WebSocketRequest.aio_call( - model='qwen-turbo', + model="qwen-turbo", task=TestTasks.streaming_duplex_text_to_text, prompt=make_input(), stream=self.stream, max_tokens=1024, ws_stream_mode=WebsocketStreamingMode.DUPLEX, is_binary_input=False, - n=50) + n=50, + ) if self.stream: async for resp in responses: - assert resp.output['text'] == 'world' + assert resp.output["text"] == "world" else: - assert responses.output['text'] == 'world' + assert responses.output["text"] == "world" @pytest.mark.asyncio async def test_streaming_duplex_text_to_binary(self, http_server): - dashscope.base_websocket_api_url = '%s/out' % base_websocket_api_url + dashscope.base_websocket_api_url = "%s/out" % base_websocket_api_url def make_input(): for i in range(10): - yield 'input message %s' % i + yield "input message %s" % i responses = await WebSocketRequest.aio_call( - model='qwen-turbo', + model="qwen-turbo", task=TestTasks.streaming_duplex_text_to_binary, prompt=make_input(), stream=self.stream, max_tokens=1024, ws_stream_mode=WebsocketStreamingMode.DUPLEX, is_binary_input=False, - n=50) + n=50, + ) if self.stream: async for resp in responses: assert resp.output == bytes([0x01] * 100) @@ -314,44 +338,48 @@ def make_input(): @pytest.mark.asyncio async def test_streaming_duplex_binary_to_text(self, http_server): - 
dashscope.base_websocket_api_url = '%s/in' % base_websocket_api_url # noqa E501 + dashscope.base_websocket_api_url = ( + "%s/in" % base_websocket_api_url + ) # noqa E501 def make_input(): for i in range(10): yield bytes([0x01] * 100) responses = await WebSocketRequest.aio_call( - model='qwen-turbo', + model="qwen-turbo", task=TestTasks.streaming_duplex_binary_to_text, prompt=make_input(), stream=self.stream, max_tokens=1024, ws_stream_mode=WebsocketStreamingMode.DUPLEX, is_binary_input=True, - n=50) + n=50, + ) if self.stream: async for resp in responses: - assert resp.output['text'] == 'world' + assert resp.output["text"] == "world" else: - assert responses.output['text'] == 'world' + assert responses.output["text"] == "world" @pytest.mark.asyncio async def test_streaming_duplex_binary_to_binary(self, http_server): - dashscope.base_websocket_api_url = '%s/inout' % base_websocket_api_url + dashscope.base_websocket_api_url = "%s/inout" % base_websocket_api_url def make_input(): for i in range(10): yield bytes([0x01] * 100) responses = await WebSocketRequest.aio_call( - model='qwen-turbo', + model="qwen-turbo", task=TestTasks.streaming_duplex_binary_to_binary, prompt=make_input(), stream=self.stream, max_tokens=1024, ws_stream_mode=WebsocketStreamingMode.DUPLEX, is_binary_input=True, - n=50) + n=50, + ) if self.stream: async for resp in responses: assert resp.output == bytes([0x01] * 100) diff --git a/tests/test_websocket_parameters.py b/tests/test_websocket_parameters.py index 72d809f..b2c1a7e 100644 --- a/tests/test_websocket_parameters.py +++ b/tests/test_websocket_parameters.py @@ -1,10 +1,14 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
import dashscope from dashscope.protocol.websocket import WebsocketStreamingMode from tests.base_test import BaseTestEnvironment -from tests.constants import (TEST_DISABLE_DATA_INSPECTION_REQUEST_ID, - TEST_ENABLE_DATA_INSPECTION_REQUEST_ID, TestTasks) +from tests.constants import ( + TEST_DISABLE_DATA_INSPECTION_REQUEST_ID, + TEST_ENABLE_DATA_INSPECTION_REQUEST_ID, + TestTasks, +) from tests.websocket_task_request import WebSocketRequest @@ -16,19 +20,19 @@ def pytest_generate_tests(metafunc): items = scenario[1].items() argnames = [x[0] for x in items] argvalues.append([x[1] for x in items]) - metafunc.parametrize(argnames, argvalues, ids=idlist, scope='class') + metafunc.parametrize(argnames, argvalues, ids=idlist, scope="class") -batch = ('batch', {'stream': False}) -stream = ('stream', {'stream': True}) +batch = ("batch", {"stream": False}) +stream = ("stream", {"stream": True}) def request_generator(): - yield 'hello' + yield "hello" # set mock server url. -base_websocket_api_url = 'ws://localhost:8080/ws/aigc/v1' +base_websocket_api_url = "ws://localhost:8080/ws/aigc/v1" # text output: text binary out put: image @@ -38,57 +42,68 @@ class TestWebSocketSyncRequest(BaseTestEnvironment): # test streaming none def test_default_disable_data_inspection(self, stream, http_server): - dashscope.base_websocket_api_url = '%s/echo' % base_websocket_api_url + dashscope.base_websocket_api_url = "%s/echo" % base_websocket_api_url responses = WebSocketRequest.call( - model='qwen-turbo', - prompt='hello', + model="qwen-turbo", + prompt="hello", task=TestTasks.streaming_none_text_to_text, stream=stream, ws_stream_mode=WebsocketStreamingMode.NONE, is_binary_input=False, max_tokens=1024, n=50, - headers={'request_id': TEST_DISABLE_DATA_INSPECTION_REQUEST_ID}) + headers={"request_id": TEST_DISABLE_DATA_INSPECTION_REQUEST_ID}, + ) if stream: for resp in responses: - assert resp.output['text'] == 'hello' + assert resp.output["text"] == "hello" else: - assert 
responses.output['text'] == 'hello' + assert responses.output["text"] == "hello" - def test_disable_data_inspection(self, stream, http_server, - mock_disable_data_inspection_env): - dashscope.base_websocket_api_url = '%s/echo' % base_websocket_api_url + def test_disable_data_inspection( + self, + stream, + http_server, + mock_disable_data_inspection_env, + ): + dashscope.base_websocket_api_url = "%s/echo" % base_websocket_api_url responses = WebSocketRequest.call( - model='qwen-turbo', - prompt='hello', + model="qwen-turbo", + prompt="hello", task=TestTasks.streaming_none_text_to_text, stream=stream, ws_stream_mode=WebsocketStreamingMode.NONE, is_binary_input=False, max_tokens=1024, n=50, - headers={'request_id': TEST_DISABLE_DATA_INSPECTION_REQUEST_ID}) + headers={"request_id": TEST_DISABLE_DATA_INSPECTION_REQUEST_ID}, + ) if stream: for resp in responses: - assert resp.output['text'] == 'hello' + assert resp.output["text"] == "hello" else: - assert responses.output['text'] == 'hello' + assert responses.output["text"] == "hello" - def test_enable_data_inspection_by_env(self, stream, http_server, - mock_enable_data_inspection_env): - dashscope.base_websocket_api_url = '%s/echo' % base_websocket_api_url + def test_enable_data_inspection_by_env( + self, + stream, + http_server, + mock_enable_data_inspection_env, + ): + dashscope.base_websocket_api_url = "%s/echo" % base_websocket_api_url responses = WebSocketRequest.call( - model='qwen-turbo', - prompt='hello', + model="qwen-turbo", + prompt="hello", task=TestTasks.streaming_none_text_to_text, stream=stream, ws_stream_mode=WebsocketStreamingMode.NONE, is_binary_input=False, max_tokens=1024, n=50, - headers={'request_id': TEST_ENABLE_DATA_INSPECTION_REQUEST_ID}) + headers={"request_id": TEST_ENABLE_DATA_INSPECTION_REQUEST_ID}, + ) if stream: for resp in responses: - assert resp.output['text'] == 'hello' + assert resp.output["text"] == "hello" else: - assert responses.output['text'] == 'hello' + assert 
responses.output["text"] == "hello" diff --git a/tests/test_websocket_sync_api.py b/tests/test_websocket_sync_api.py index d1c631c..864d916 100644 --- a/tests/test_websocket_sync_api.py +++ b/tests/test_websocket_sync_api.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. import dashscope @@ -15,19 +16,19 @@ def pytest_generate_tests(metafunc): items = scenario[1].items() argnames = [x[0] for x in items] argvalues.append([x[1] for x in items]) - metafunc.parametrize(argnames, argvalues, ids=idlist, scope='class') + metafunc.parametrize(argnames, argvalues, ids=idlist, scope="class") -batch = ('batch', {'stream': False}) -stream = ('stream', {'stream': True}) +batch = ("batch", {"stream": False}) +stream = ("stream", {"stream": True}) def request_generator(): - yield 'hello' + yield "hello" # set mock server url. -base_websocket_api_url = 'ws://localhost:8080/ws/aigc/v1' +base_websocket_api_url = "ws://localhost:8080/ws/aigc/v1" # text output: text binary out put: image @@ -37,33 +38,35 @@ class TestWebSocketSyncRequest(BaseTestEnvironment): # test streaming none def test_streaming_none_text_to_text(self, stream, http_server): - dashscope.base_websocket_api_url = '%s/echo' % base_websocket_api_url + dashscope.base_websocket_api_url = "%s/echo" % base_websocket_api_url responses = WebSocketRequest.call( - model='qwen-turbo', - prompt='hello', + model="qwen-turbo", + prompt="hello", task=TestTasks.streaming_none_text_to_text, stream=stream, ws_stream_mode=WebsocketStreamingMode.NONE, is_binary_input=False, max_tokens=1024, - n=50) + n=50, + ) if stream: for resp in responses: - assert resp.output['text'] == 'hello' + assert resp.output["text"] == "hello" else: - assert responses.output['text'] == 'hello' + assert responses.output["text"] == "hello" def test_streaming_none_text_to_binary(self, stream, http_server): - dashscope.base_websocket_api_url = '%s/out' % base_websocket_api_url + dashscope.base_websocket_api_url = 
"%s/out" % base_websocket_api_url responses = WebSocketRequest.call( - model='qwen-turbo', + model="qwen-turbo", task=TestTasks.streaming_none_text_to_binary, - prompt='hello', + prompt="hello", stream=stream, ws_stream_mode=WebsocketStreamingMode.NONE, is_binary_input=False, max_tokens=1024, - n=50) + n=50, + ) if stream: for resp in responses: assert resp.output == bytes([0x01] * 100) @@ -71,36 +74,38 @@ def test_streaming_none_text_to_binary(self, stream, http_server): assert responses.output == bytes([0x01] * 100) def test_streaming_none_binary_to_text(self, stream, http_server): - dashscope.base_websocket_api_url = '%s/in' % base_websocket_api_url + dashscope.base_websocket_api_url = "%s/in" % base_websocket_api_url video = bytes([0x01] * 100) responses = WebSocketRequest.call( - model='qwen-turbo', + model="qwen-turbo", task=TestTasks.streaming_none_binary_to_text, prompt=video, stream=stream, ws_stream_mode=WebsocketStreamingMode.NONE, is_binary_input=True, max_tokens=1024, - n=50) + n=50, + ) if stream: for resp in responses: - assert resp.output['text'] == 'world' + assert resp.output["text"] == "world" else: - assert responses.output['text'] == 'world' + assert responses.output["text"] == "world" def test_streaming_none_binary_to_binary(self, stream, http_server): - dashscope.base_websocket_api_url = '%s/inout' % base_websocket_api_url + dashscope.base_websocket_api_url = "%s/inout" % base_websocket_api_url video = bytes([0x01] * 100) responses = WebSocketRequest.call( - model='qwen-turbo', + model="qwen-turbo", task=TestTasks.streaming_none_binary_to_binary, prompt=video, stream=stream, ws_stream_mode=WebsocketStreamingMode.NONE, is_binary_input=True, max_tokens=1024, - n=50) + n=50, + ) if stream: for resp in responses: assert resp.output == bytes([0x01] * 100) @@ -109,21 +114,22 @@ def test_streaming_none_binary_to_binary(self, stream, http_server): # test string in def test_streaming_in_text_to_text(self, stream, http_server): - 
dashscope.base_websocket_api_url = '%s/echo' % base_websocket_api_url + dashscope.base_websocket_api_url = "%s/echo" % base_websocket_api_url def make_input(): for i in range(10): - yield 'input message %s' % i + yield "input message %s" % i responses = WebSocketRequest.call( - model='qwen-turbo', + model="qwen-turbo", task=TestTasks.streaming_in_text_to_text, prompt=make_input(), stream=stream, ws_stream_mode=WebsocketStreamingMode.IN, is_binary_input=False, max_tokens=1024, - n=50) + n=50, + ) if stream: for resp in responses: assert len(resp.output) == 1 @@ -131,21 +137,22 @@ def make_input(): assert len(responses.output) == 1 # echo the input out. def test_streaming_in_text_to_binary(self, stream, http_server): - dashscope.base_websocket_api_url = '%s/out' % base_websocket_api_url + dashscope.base_websocket_api_url = "%s/out" % base_websocket_api_url def make_input(): for i in range(10): - yield 'input message %s' % i + yield "input message %s" % i responses = WebSocketRequest.call( - model='qwen-turbo', + model="qwen-turbo", task=TestTasks.streaming_in_text_to_binary, prompt=make_input(), stream=stream, ws_stream_mode=WebsocketStreamingMode.IN, is_binary_input=False, max_tokens=1024, - n=50) + n=50, + ) if stream: for resp in responses: assert resp.output == bytes([0x01] * 100) @@ -153,18 +160,19 @@ def make_input(): assert responses.output == bytes([0x01] * 100) def test_streaming_in_text_to_text_with_file(self, stream, http_server): - dashscope.base_websocket_api_url = '%s/echo' % base_websocket_api_url - text_file = open('tests/data/multi_line.txt', encoding='utf-8') + dashscope.base_websocket_api_url = "%s/echo" % base_websocket_api_url + text_file = open("tests/data/multi_line.txt", encoding="utf-8") responses = WebSocketRequest.call( - model='qwen-turbo', + model="qwen-turbo", task=TestTasks.streaming_in_text_to_text, prompt=text_file, stream=stream, ws_stream_mode=WebsocketStreamingMode.IN, is_binary_input=False, max_tokens=1024, - n=50) + n=50, + ) if 
stream: for resp in responses: assert len(resp.output) == 1 @@ -172,21 +180,22 @@ def test_streaming_in_text_to_text_with_file(self, stream, http_server): assert len(responses.output) == 1 # echo the input out. def test_streaming_in_text_to_binary_generator(self, stream, http_server): - dashscope.base_websocket_api_url = '%s/out' % base_websocket_api_url + dashscope.base_websocket_api_url = "%s/out" % base_websocket_api_url def make_input(): for i in range(10): - yield 'input message %s' % i + yield "input message %s" % i responses = WebSocketRequest.call( - model='qwen-turbo', + model="qwen-turbo", task=TestTasks.streaming_in_text_to_binary, prompt=make_input(), stream=stream, ws_stream_mode=WebsocketStreamingMode.IN, is_binary_input=False, max_tokens=1024, - n=50) + n=50, + ) if stream: for resp in responses: assert resp.output == bytes([0x01] * 100) @@ -194,44 +203,46 @@ def make_input(): assert responses.output == bytes([0x01] * 100) def test_streaming_in_binary_to_text(self, stream, http_server): - dashscope.base_websocket_api_url = '%s/in' % base_websocket_api_url + dashscope.base_websocket_api_url = "%s/in" % base_websocket_api_url def make_input(): for i in range(10): yield bytes([0x01] * 100) responses = WebSocketRequest.call( - model='qwen-turbo', + model="qwen-turbo", task=TestTasks.streaming_in_binary_to_text, prompt=make_input(), stream=stream, ws_stream_mode=WebsocketStreamingMode.IN, is_binary_input=True, max_tokens=1024, - n=50) + n=50, + ) if stream: for resp in responses: - assert resp.output['text'] == 'world' + assert resp.output["text"] == "world" else: - assert responses.output['text'] == 'world' + assert responses.output["text"] == "world" def test_streaming_in_binary_to_binary(self, stream, http_server): - dashscope.base_websocket_api_url = '%s/inout' % base_websocket_api_url + dashscope.base_websocket_api_url = "%s/inout" % base_websocket_api_url def make_input(): for i in range(10): yield bytes([0x01] * 100) responses = 
WebSocketRequest.call( - model='qwen-turbo', + model="qwen-turbo", task=TestTasks.streaming_in_binary_to_binary, prompt=make_input(), stream=stream, ws_stream_mode=WebsocketStreamingMode.IN, is_binary_input=True, max_tokens=1024, - n=50) + n=50, + ) if stream: for resp in responses: assert resp.output == bytes([0x01] * 100) @@ -240,50 +251,53 @@ def make_input(): # streaming out def test_streaming_out_text_to_text(self, stream, http_server): - dashscope.base_websocket_api_url = '%s/echo' % base_websocket_api_url + dashscope.base_websocket_api_url = "%s/echo" % base_websocket_api_url responses = WebSocketRequest.call( - model='qwen-turbo', + model="qwen-turbo", task=TestTasks.streaming_out_text_to_text, - prompt='hello', + prompt="hello", stream=stream, ws_stream_mode=WebsocketStreamingMode.OUT, is_binary_input=False, max_tokens=1024, - n=50) + n=50, + ) if stream: for resp in responses: - assert resp.output['text'] == 'world' + assert resp.output["text"] == "world" else: - responses.output['text'] == 'world' + responses.output["text"] == "world" def test_streaming_out_text_to_text_stream(self, stream, http_server): - dashscope.base_websocket_api_url = '%s/echo' % base_websocket_api_url + dashscope.base_websocket_api_url = "%s/echo" % base_websocket_api_url responses = WebSocketRequest.call( - model='qwen-turbo', + model="qwen-turbo", task=TestTasks.streaming_out_text_to_text, - prompt='hello', + prompt="hello", stream=stream, ws_stream_mode=WebsocketStreamingMode.OUT, is_binary_input=False, max_tokens=1024, - n=50) + n=50, + ) if stream: for resp in responses: - assert resp.output['text'] == 'world' + assert resp.output["text"] == "world" else: - assert responses.output['text'] == 'hello' + assert responses.output["text"] == "hello" def test_streaming_out_text_to_binary(self, stream, http_server): - dashscope.base_websocket_api_url = '%s/out' % base_websocket_api_url + dashscope.base_websocket_api_url = "%s/out" % base_websocket_api_url responses = 
WebSocketRequest.call( - model='qwen-turbo', + model="qwen-turbo", task=TestTasks.streaming_out_text_to_binary, - prompt='hello', + prompt="hello", stream=stream, ws_stream_mode=WebsocketStreamingMode.OUT, is_binary_input=False, max_tokens=1024, - n=50) + n=50, + ) if stream: for resp in responses: assert resp.output == bytes([0x01] * 100) @@ -291,34 +305,36 @@ def test_streaming_out_text_to_binary(self, stream, http_server): assert responses.output == bytes([0x01] * 100) def test_streaming_out_binary_to_text(self, stream, http_server): - dashscope.base_websocket_api_url = '%s/in' % base_websocket_api_url + dashscope.base_websocket_api_url = "%s/in" % base_websocket_api_url responses = WebSocketRequest.call( - model='qwen-turbo', + model="qwen-turbo", task=TestTasks.streaming_out_binary_to_text, prompt=bytes([0x01] * 100), max_tokens=1024, stream=stream, ws_stream_mode=WebsocketStreamingMode.OUT, is_binary_input=True, - n=50) + n=50, + ) if stream: for resp in responses: - assert resp.output['text'] == 'world' + assert resp.output["text"] == "world" else: - assert responses.output['text'] == 'world' + assert responses.output["text"] == "world" def test_streaming_out_binary_to_binary(self, stream, http_server): - dashscope.base_websocket_api_url = '%s/inout' % base_websocket_api_url + dashscope.base_websocket_api_url = "%s/inout" % base_websocket_api_url responses = WebSocketRequest.call( - model='qwen-turbo', + model="qwen-turbo", task=TestTasks.streaming_out_binary_to_binary, prompt=bytes([0x01] * 100), stream=stream, ws_stream_mode=WebsocketStreamingMode.OUT, is_binary_input=True, max_tokens=1024, - n=50) + n=50, + ) if stream: for resp in responses: assert resp.output == bytes([0x01] * 100) @@ -327,43 +343,45 @@ def test_streaming_out_binary_to_binary(self, stream, http_server): # streaming duplex def test_streaming_duplex_text_to_text(self, stream, http_server): - dashscope.base_websocket_api_url = '%s/echo' % base_websocket_api_url + 
dashscope.base_websocket_api_url = "%s/echo" % base_websocket_api_url def make_input(): for i in range(10): - yield 'input message %s' % i + yield "input message %s" % i responses = WebSocketRequest.call( - model='qwen-turbo', + model="qwen-turbo", task=TestTasks.streaming_duplex_text_to_text, prompt=make_input(), stream=stream, ws_stream_mode=WebsocketStreamingMode.DUPLEX, is_binary_input=False, max_tokens=1024, - n=50) + n=50, + ) if stream: for resp in responses: - assert resp.output['text'] == 'world' + assert resp.output["text"] == "world" else: - assert responses.output['text'] == 'world' + assert responses.output["text"] == "world" def test_streaming_duplex_text_to_binary(self, stream, http_server): - dashscope.base_websocket_api_url = '%s/out' % base_websocket_api_url + dashscope.base_websocket_api_url = "%s/out" % base_websocket_api_url def make_input(): for i in range(10): - yield 'input message %s' % i + yield "input message %s" % i responses = WebSocketRequest.call( - model='qwen-turbo', + model="qwen-turbo", task=TestTasks.streaming_duplex_text_to_binary, prompt=make_input(), stream=stream, ws_stream_mode=WebsocketStreamingMode.DUPLEX, is_binary_input=False, max_tokens=1024, - n=50) + n=50, + ) if stream: for resp in responses: @@ -372,43 +390,45 @@ def make_input(): assert responses.output == bytes([0x01] * 100) def test_streaming_duplex_binary_to_text(self, stream, http_server): - dashscope.base_websocket_api_url = '%s/in' % base_websocket_api_url + dashscope.base_websocket_api_url = "%s/in" % base_websocket_api_url def make_input(): for i in range(10): yield bytes([0x01] * 100) responses = WebSocketRequest.call( - model='qwen-turbo', + model="qwen-turbo", task=TestTasks.streaming_duplex_binary_to_text, prompt=make_input(), stream=stream, ws_stream_mode=WebsocketStreamingMode.DUPLEX, is_binary_input=True, max_tokens=1024, - n=50) + n=50, + ) if stream: for resp in responses: - assert resp.output['text'] == 'world' + assert resp.output["text"] == 
"world" else: - assert responses.output['text'] == 'world' + assert responses.output["text"] == "world" def test_streaming_duplex_binary_to_binary(self, stream, http_server): - dashscope.base_websocket_api_url = '%s/inout' % base_websocket_api_url + dashscope.base_websocket_api_url = "%s/inout" % base_websocket_api_url def make_input(): for i in range(10): yield bytes([0x01] * 100) responses = WebSocketRequest.call( - model='qwen-turbo', + model="qwen-turbo", task=TestTasks.streaming_duplex_binary_to_binary, prompt=make_input(), stream=stream, ws_stream_mode=WebsocketStreamingMode.DUPLEX, is_binary_input=True, max_tokens=1024, - n=50) + n=50, + ) if stream: for resp in responses: assert resp.output == bytes([0x01] * 100) @@ -416,19 +436,25 @@ def make_input(): assert responses.output == bytes([0x01] * 100) def test_streaming_duplex_binary_to_binary_with_input_file( - self, stream, http_server): - dashscope.base_websocket_api_url = '%s/inout' % base_websocket_api_url - binary_file = open('tests/data/action_recognition_test_video.mp4', - 'rb') # TODO no rb + self, + stream, + http_server, + ): + dashscope.base_websocket_api_url = "%s/inout" % base_websocket_api_url + binary_file = open( + "tests/data/action_recognition_test_video.mp4", + "rb", + ) # TODO no rb responses = WebSocketRequest.call( - model='qwen-turbo', + model="qwen-turbo", task=TestTasks.streaming_duplex_binary_to_binary, prompt=binary_file, stream=stream, ws_stream_mode=WebsocketStreamingMode.DUPLEX, is_binary_input=True, max_tokens=1024, - n=50) + n=50, + ) if stream: for resp in responses: assert resp.output == bytes([0x01] * 100) diff --git a/tests/websocket_mock_server_task_handler.py b/tests/websocket_mock_server_task_handler.py index f1ecc88..742d4bd 100644 --- a/tests/websocket_mock_server_task_handler.py +++ b/tests/websocket_mock_server_task_handler.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
import asyncio @@ -6,18 +7,32 @@ import aiohttp from dashscope.common.error import UnexpectedMessageReceived -from dashscope.protocol.websocket import (ACTION_KEY, EVENT_KEY, ActionType, - EventType, WebsocketStreamingMode) +from dashscope.protocol.websocket import ( + ACTION_KEY, + EVENT_KEY, + ActionType, + EventType, + WebsocketStreamingMode, +) class WebSocketTaskProcessor: - """WebSocket state machine. - """ - def __init__(self, ws, task_id, streaming_mode, model, task, is_binary_in, - is_binary_out, run_task_json_message) -> None: + """WebSocket state machine.""" + + def __init__( + self, + ws, + task_id, + streaming_mode, + model, + task, + is_binary_in, + is_binary_out, + run_task_json_message, + ) -> None: self.ws = ws - self.error = '' - self.error_message = '' + self.error = "" + self.error_message = "" self.task_id = task_id self.model = model self.task = task @@ -28,14 +43,14 @@ def __init__(self, ws, task_id, streaming_mode, model, task, is_binary_in, self._duplex_task_finished = False async def aio_call(self): - await self._send_start_event( - ) # no matter what, send start event first. + await self._send_start_event() # no matter what, send start event first. if self.streaming_mode == WebsocketStreamingMode.NONE: # if binary data, we need to receive data if self.is_binary_in: - binary_data = await self._receive_batch_binary( + binary_data = ( + await self._receive_batch_binary() ) # ignore timeout. - print('Receive binary data, length: %s' % len(binary_data)) + print("Receive binary data, length: %s" % len(binary_data)) # send "event":"task-finished" if self.is_binary_out: # send binary data @@ -44,27 +59,29 @@ async def aio_call(self): elif self.is_binary_in: return await self._send_task_finished( payload={ - 'output': { - 'text': 'world' + "output": { + "text": "world", + }, + "usage": { + "input_tokens": 1, + "output_tokens": 200, }, - 'usage': { - 'input_tokens': 1, - 'output_tokens': 200 - } - }) # binary input, result with world. 
+ }, + ) # binary input, result with world. else: await self._send_task_finished( payload={ - 'output': { - 'text': - self.run_task_json_message['payload']['input'] - ['prompt'] + "output": { + "text": self.run_task_json_message["payload"][ + "input" + ]["prompt"], }, - 'usage': { - 'input_tokens': 100, - 'output_tokens': 200 - } # noqa E501 - }) # for echo message out. + "usage": { + "input_tokens": 100, + "output_tokens": 200, + }, # noqa E501 + }, + ) # for echo message out. elif self.streaming_mode == WebsocketStreamingMode.IN: if self.is_binary_in: # binary data @@ -79,25 +96,27 @@ async def aio_call(self): if self.is_binary_in: return await self._send_task_finished( payload={ - 'output': { - 'text': 'world' + "output": { + "text": "world", }, - 'usage': { - 'input_tokens': 1, - 'output_tokens': 200 - } - }) # noqa E501 + "usage": { + "input_tokens": 1, + "output_tokens": 200, + }, + }, + ) # noqa E501 else: await self._send_task_finished( payload={ - 'output': { - 'text': 'world' + "output": { + "text": "world", }, - 'usage': { - 'input_tokens': 1, - 'output_tokens': 200 - } - }) + "usage": { + "input_tokens": 1, + "output_tokens": 200, + }, + }, + ) elif self.streaming_mode == WebsocketStreamingMode.OUT: if self.is_binary_in: # run task without data, data is binary. 
binary_data = await self._receive_batch_binary() @@ -115,16 +134,20 @@ async def aio_call(self): else: # duplex mode if self.is_binary_in: send_task = asyncio.create_task( - self._receive_streaming_binary_data()) + self._receive_streaming_binary_data(), + ) else: send_task = asyncio.create_task( - self._receive_streaming_text_data()) + self._receive_streaming_text_data(), + ) if self.is_binary_out: receive_task = asyncio.create_task( - self.send_streaming_binary_output()) + self.send_streaming_binary_output(), + ) else: receive_task = asyncio.create_task( - self.send_streaming_text_output()) + self.send_streaming_text_output(), + ) _, _ = await asyncio.gather(receive_task, send_task) @@ -137,38 +160,38 @@ async def send_streaming_binary_output(self): async def send_streaming_text_output(self): headers = { - 'task_id': self.task_id, - 'event': 'result-generated', + "task_id": self.task_id, + "event": "result-generated", } for i in range(10): payload = { - 'output': { - 'text': 'world' + "output": { + "text": "world", + }, + "usage": { + "input_tokens": 10, + "output_tokens": 20, }, - 'usage': { - 'input_tokens': 10, - 'output_tokens': 20 - } } msg = self._build_up_message(headers=headers, payload=payload) await self.ws.send_str(msg) - print('send_streaming_text_output finished!') + print("send_streaming_text_output finished!") async def send_batch_streaming_output(self): data = bytes([0x01] * 100) await self.ws.send_bytes(data) async def _send_start_event(self): - headers = {'task_id': self.task_id, EVENT_KEY: EventType.STARTED} + headers = {"task_id": self.task_id, EVENT_KEY: EventType.STARTED} payload = {} message = self._build_up_message(headers, payload=payload) - print('sending task started event message: %s' % message) + print("sending task started event message: %s" % message) await self.ws.send_str(message) async def _send_task_finished(self, payload): - headers = {'task_id': self.task_id, EVENT_KEY: EventType.FINISHED} + headers = {"task_id": self.task_id, 
EVENT_KEY: EventType.FINISHED} message = self._build_up_message(headers, payload) - print('sending task finished message: %s' % message) + print("sending task finished message: %s" % message) await self.ws.send_str(message) async def _receive_streaming_binary_data(self): @@ -178,44 +201,46 @@ async def _receive_streaming_binary_data(self): return if msg.type == aiohttp.WSMsgType.BINARY: print( - 'Receive binary data length: %s' % - len(msg.data)) # real server need return data and process. + "Receive binary data length: %s" % len(msg.data), + ) # real server need return data and process. elif msg.type == aiohttp.WSMsgType.TEXT: req = msg.json() - print('Receive %s event' % req['header'][ACTION_KEY]) - if req['header'][ACTION_KEY] == ActionType.FINISHED: + print("Receive %s event" % req["header"][ACTION_KEY]) + if req["header"][ACTION_KEY] == ActionType.FINISHED: self._duplex_task_finished = True break else: - print('Unknown message: %s' % msg) + print("Unknown message: %s" % msg) else: raise UnexpectedMessageReceived( - 'Expect binary data but receive %s!' % msg.type) + "Expect binary data but receive %s!" 
% msg.type, + ) async def _receive_streaming_text_data(self): payload = [] - payload.append(self.run_task_json_message['payload']['input']) + payload.append(self.run_task_json_message["payload"]["input"]) while True: msg = await self.ws.receive() if await self.validate_message(msg): return if msg.type == aiohttp.WSMsgType.TEXT: msg_json = msg.json() - print('Receive %s event' % msg_json['header'][ACTION_KEY]) - if msg_json['header'][ACTION_KEY] == ActionType.CONTINUE: - print('Receive text data: ' % msg_json['payload']) - payload.append(msg_json['payload']) - elif msg_json['header'][ACTION_KEY] == ActionType.FINISHED: - print('Receive text data: ' % msg_json['payload']) - if msg_json['payload']: - payload.append(msg_json['payload']) + print("Receive %s event" % msg_json["header"][ACTION_KEY]) + if msg_json["header"][ACTION_KEY] == ActionType.CONTINUE: + print("Receive text data: " % msg_json["payload"]) + payload.append(msg_json["payload"]) + elif msg_json["header"][ACTION_KEY] == ActionType.FINISHED: + print("Receive text data: " % msg_json["payload"]) + if msg_json["payload"]: + payload.append(msg_json["payload"]) self._duplex_task_finished = True return payload else: - print('Unknown message: %s' % msg_json) + print("Unknown message: %s" % msg_json) else: raise UnexpectedMessageReceived( - 'Expect binary data but receive %s!' % msg.type) + "Expect binary data but receive %s!" % msg.type, + ) async def _receive_batch_binary(self): """If the data is not binary, data is send in start package. @@ -233,7 +258,8 @@ async def _receive_batch_binary(self): return msg.data else: raise UnexpectedMessageReceived( - 'Expect binary data but receive %s!' % msg.type) + "Expect binary data but receive %s!" % msg.type, + ) async def _receive_batch_text(self): """If the data is not binary, data is send in start package. 
@@ -243,33 +269,33 @@ async def _receive_batch_text(self): Returns: No: """ - final_data = self.run_task_json_message['payload'] + final_data = self.run_task_json_message["payload"] while True: msg = await self.ws.receive() if self.validate_message(): break if msg.type == aiohttp.WSMsgType.TEXT: req = msg.json() - print('Receive %s event' % req['header'][ACTION_KEY]) - if req['header'][ACTION_KEY] == ActionType.START: - print('receive start task event') - elif req['header'][ACTION_KEY] == ActionType.FINISHED: + print("Receive %s event" % req["header"][ACTION_KEY]) + if req["header"][ACTION_KEY] == ActionType.START: + print("receive start task event") + elif req["header"][ACTION_KEY] == ActionType.FINISHED: # client is finished, send finished task binary response. await self._send_task_finished(final_data) break else: - print('Unknown message: %s' % msg) + print("Unknown message: %s" % msg) else: - raise UnexpectedMessageReceived('Expect text %s!' % msg.type) + raise UnexpectedMessageReceived("Expect text %s!" % msg.type) def _build_up_message(self, headers, payload): - message = {'header': headers, 'payload': payload} + message = {"header": headers, "payload": payload} return json.dumps(message) async def validate_message(self, msg): if msg.type == aiohttp.WSMsgType.CLOSED: - print('Client close the connection') + print("Client close the connection") elif msg.type == aiohttp.WSMsgType.ERROR: - print('Connection error: %s' % msg.data) + print("Connection error: %s" % msg.data) return True return False diff --git a/tests/websocket_task_request.py b/tests/websocket_task_request.py index 24a308d..0c3d17f 100644 --- a/tests/websocket_task_request.py +++ b/tests/websocket_task_request.py @@ -1,56 +1,66 @@ +# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. 
-from dashscope.api_entities.dashscope_response import (DashScopeAPIResponse, - GenerationResponse) +from dashscope.api_entities.dashscope_response import ( + DashScopeAPIResponse, + GenerationResponse, +) from dashscope.client.base_api import BaseAioApi, BaseApi from dashscope.common.constants import ApiProtocol from dashscope.protocol.websocket import WebsocketStreamingMode class WebSocketRequest(BaseApi, BaseAioApi): - """API for AI-Generated Content(AIGC) models. + """API for AI-Generated Content(AIGC) models.""" - """ @classmethod - async def aio_call(cls, - model: str, - prompt: str, - task: str, - task_group: str = 'aigc', - api_key: str = None, - api_protocol=ApiProtocol.WEBSOCKET, - ws_stream_mode=WebsocketStreamingMode.NONE, - is_binary_input=False, - **kwargs) -> DashScopeAPIResponse: - return await BaseAioApi.call(model=model, - task_group=task_group, - task=task, - api_key=api_key, - input={'prompt': prompt}, - api_protocol=api_protocol, - ws_stream_mode=ws_stream_mode, - is_binary_input=is_binary_input, - **kwargs) + async def aio_call( + cls, + model: str, + prompt: str, + task: str, + task_group: str = "aigc", + api_key: str = None, + api_protocol=ApiProtocol.WEBSOCKET, + ws_stream_mode=WebsocketStreamingMode.NONE, + is_binary_input=False, + **kwargs, + ) -> DashScopeAPIResponse: + return await BaseAioApi.call( + model=model, + task_group=task_group, + task=task, + api_key=api_key, + input={"prompt": prompt}, + api_protocol=api_protocol, + ws_stream_mode=ws_stream_mode, + is_binary_input=is_binary_input, + **kwargs, + ) @classmethod - def call(cls, - model: str, - prompt: str, - task: str, - task_group: str = 'aigc', - api_key: str = None, - api_protocol=ApiProtocol.WEBSOCKET, - ws_stream_mode=WebsocketStreamingMode.NONE, - is_binary_input=False, - **kwargs) -> GenerationResponse: - response = BaseApi.call(model=model, - task_group=task_group, - task=task, - api_key=api_key, - input={'prompt': prompt}, - api_protocol=api_protocol, - 
ws_stream_mode=ws_stream_mode, - is_binary_input=is_binary_input, - **kwargs) + def call( + cls, + model: str, + prompt: str, + task: str, + task_group: str = "aigc", + api_key: str = None, + api_protocol=ApiProtocol.WEBSOCKET, + ws_stream_mode=WebsocketStreamingMode.NONE, + is_binary_input=False, + **kwargs, + ) -> GenerationResponse: + response = BaseApi.call( + model=model, + task_group=task_group, + task=task, + api_key=api_key, + input={"prompt": prompt}, + api_protocol=api_protocol, + ws_stream_mode=ws_stream_mode, + is_binary_input=is_binary_input, + **kwargs, + ) return response