From e9ddce5cabd3286e7965637098f6b3dde2ee16e2 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Thu, 16 May 2024 05:58:46 +0800 Subject: [PATCH] [Frontend] Re-enable custom roles in Chat Completions API (#4758) --- tests/entrypoints/test_openai_server.py | 30 +++++++++++ vllm/entrypoints/openai/protocol.py | 38 +++++++++++++- vllm/entrypoints/openai/serving_chat.py | 66 ++++++++++++++++--------- 3 files changed, 108 insertions(+), 26 deletions(-) diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index ee2f034fd2c4..1b04e3205c4b 100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -783,6 +783,36 @@ async def test_complex_message_content(server, client: openai.AsyncOpenAI): assert content == "2" +async def test_custom_role(server, client: openai.AsyncOpenAI): + # Not sure how the model handles custom roles so we just check that + # both string and complex message content are handled in the same way + + resp1 = await client.chat.completions.create( + model=MODEL_NAME, + messages=[{ + "role": "my-custom-role", + "content": "what is 1+1?", + }], # type: ignore + temperature=0, + seed=0) + + resp2 = await client.chat.completions.create( + model=MODEL_NAME, + messages=[{ + "role": "my-custom-role", + "content": [{ + "type": "text", + "text": "what is 1+1?" + }] + }], # type: ignore + temperature=0, + seed=0) + + content1 = resp1.choices[0].message.content + content2 = resp2.choices[0].message.content + assert content1 == content2 + + async def test_guided_grammar(server, client: openai.AsyncOpenAI): simple_sql_grammar = """ start: select_statement diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 139c5716c7ce..35dfa09ac12b 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -3,16 +3,50 @@ import time from typing import Any, Dict, List, Literal, Optional, Union +import openai.types.chat import torch -from openai.types.chat import ChatCompletionMessageParam from pydantic import BaseModel, ConfigDict, Field, model_validator -from typing_extensions import Annotated +# pydantic needs the TypedDict from typing_extensions +from typing_extensions import Annotated, Required, TypedDict from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.utils import random_uuid +class CustomChatCompletionContentPartParam(TypedDict, total=False): + __pydantic_config__ = ConfigDict(extra="allow") # type: ignore + + type: Required[str] + """The type of the content part.""" + + +ChatCompletionContentPartParam = Union[ + openai.types.chat.ChatCompletionContentPartParam, + CustomChatCompletionContentPartParam] + + +class CustomChatCompletionMessageParam(TypedDict, total=False): + """Enables custom roles in the Chat Completion API.""" + role: Required[str] + """The role of the message's author.""" + + content: Union[str, List[ChatCompletionContentPartParam]] + """The contents of the message.""" + + name: str + """An optional name for the participant. + + Provides the model information to differentiate between participants of the + same role. + """ + + +ChatCompletionMessageParam = Union[ + openai.types.chat.ChatCompletionMessageParam, + CustomChatCompletionMessageParam] + + class OpenAIBaseModel(BaseModel): # OpenAI API does not allow extra fields model_config = ConfigDict(extra="forbid") diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 1b469fc59b07..65824a2206be 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -1,15 +1,16 @@ import codecs import time -from typing import (AsyncGenerator, AsyncIterator, Awaitable, Iterable, List, - Optional, Tuple, TypedDict, Union, final) +from dataclasses import dataclass +from typing import (AsyncGenerator, AsyncIterator, Iterable, List, Optional, + TypedDict, Union, cast, final) from fastapi import Request -from openai.types.chat import (ChatCompletionContentPartParam, - ChatCompletionRole) +from openai.types.chat import ChatCompletionContentPartTextParam from vllm.config import ModelConfig from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.entrypoints.openai.protocol import ( + ChatCompletionContentPartParam, ChatCompletionMessageParam, ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse, @@ -31,6 +32,11 @@ class ConversationMessage(TypedDict): content: str +@dataclass(frozen=True) +class ChatMessageParseResult: + messages: List[ConversationMessage] + + class OpenAIServingChat(OpenAIServing): def __init__(self, @@ -77,27 +83,40 @@ def _load_chat_template(self, chat_template: Optional[str]): logger.warning( "No chat template provided. Chat API will not work.") - def _parse_chat_message_content( + def _parse_chat_message_content_parts( self, - role: ChatCompletionRole, - content: Optional[Union[str, - Iterable[ChatCompletionContentPartParam]]], - ) -> Tuple[List[ConversationMessage], List[Awaitable[object]]]: - if content is None: - return [], [] - if isinstance(content, str): - return [ConversationMessage(role=role, content=content)], [] - + role: str, + parts: Iterable[ChatCompletionContentPartParam], + ) -> ChatMessageParseResult: texts: List[str] = [] - for _, part in enumerate(content): - if part["type"] == "text": - text = part["text"] + + for _, part in enumerate(parts): + part_type = part["type"] + if part_type == "text": + text = cast(ChatCompletionContentPartTextParam, part)["text"] texts.append(text) else: - raise NotImplementedError(f"Unknown part type: {part['type']}") + raise NotImplementedError(f"Unknown part type: {part_type}") + + messages = [ConversationMessage(role=role, content="\n".join(texts))] + + return ChatMessageParseResult(messages=messages) + + def _parse_chat_message_content( + self, + message: ChatCompletionMessageParam, + ) -> ChatMessageParseResult: + role = message["role"] + content = message.get("content") + + if content is None: + return ChatMessageParseResult(messages=[]) + if isinstance(content, str): + messages = [ConversationMessage(role=role, content=content)] + return ChatMessageParseResult(messages=messages) - return [ConversationMessage(role=role, content="\n".join(texts))], [] + return self._parse_chat_message_content_parts(role, content) async def create_chat_completion( self, request: ChatCompletionRequest, raw_request: Request @@ -119,11 +138,10 @@ async def create_chat_completion( try: conversation: List[ConversationMessage] = [] - for m in request.messages: - messages, _ = self._parse_chat_message_content( - m["role"], m["content"]) + for msg in request.messages: + parsed_msg = self._parse_chat_message_content(msg) - conversation.extend(messages) + conversation.extend(parsed_msg.messages) prompt = self.tokenizer.apply_chat_template( conversation=conversation, @@ -387,4 +405,4 @@ async def chat_completion_full_generator( usage=usage, ) - return response \ No newline at end of file + return response