[BUGFIX] [FRONTEND] Correct chat logprobs (vllm-project#5029)
Co-authored-by: Breno Faria <breno.faria@intrafind.com>
2 people authored and dtrifiro committed May 31, 2024
1 parent 895db5d commit 110124d
Showing 6 changed files with 361 additions and 98 deletions.
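In practice, the fix aligns chat logprobs with the OpenAI chat response shape: per-token entries under choices[0].logprobs.content, each carrying its own top_logprobs list, instead of the completion-style top_logprobs list of dicts. Below is a minimal client-side sketch of reading the corrected structure, mirroring the tests in this commit; the base URL, API key, and model name are placeholder assumptions for a locally running vLLM OpenAI-compatible server.

import asyncio

import openai


async def main() -> None:
    # Placeholder connection details; point these at your own server/model.
    client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                api_key="EMPTY")
    chat = await client.chat.completions.create(
        model="HuggingFaceH4/zephyr-7b-beta",  # placeholder model name
        messages=[{"role": "user", "content": "what is 1+1?"}],
        max_tokens=5,
        temperature=0.0,
        logprobs=True,
        top_logprobs=5,
    )
    # One entry per generated token; each entry carries token, logprob, bytes
    # and its own list of top_logprobs alternatives.
    for entry in chat.choices[0].logprobs.content:
        print(entry.token, entry.logprob,
              [(alt.token, alt.logprob) for alt in entry.top_logprobs])


asyncio.run(main())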
6 changes: 4 additions & 2 deletions tests/async_engine/test_openapi_server_ray.py
@@ -94,8 +94,10 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI):
chat_completion.choices) == 1
assert chat_completion.choices[0].message is not None
assert chat_completion.choices[0].logprobs is not None
- assert chat_completion.choices[0].logprobs.top_logprobs is not None
- assert len(chat_completion.choices[0].logprobs.top_logprobs[0]) == 5
+ assert chat_completion.choices[0].logprobs.content[
+     0].top_logprobs is not None
+ assert len(
+     chat_completion.choices[0].logprobs.content[0].top_logprobs) == 5
message = chat_completion.choices[0].message
assert message.content is not None and len(message.content) >= 10
assert message.role == "assistant"
209 changes: 181 additions & 28 deletions tests/entrypoints/test_openai_server.py
@@ -184,6 +184,26 @@ async def test_single_completion(server, client: openai.AsyncOpenAI,
completion.choices[0].text) >= 5


@pytest.mark.asyncio
@pytest.mark.parametrize(
# first test base model, then test loras
"model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
)
async def test_no_logprobs(server, client: openai.AsyncOpenAI,
model_name: str):
# test using token IDs
completion = await client.completions.create(
model=MODEL_NAME,
prompt=[0, 0, 0, 0, 0],
max_tokens=5,
temperature=0.0,
logprobs=None,
)
choice = completion.choices[0]
assert choice.logprobs is None


@pytest.mark.asyncio
@pytest.mark.parametrize(
# first test base model, then test loras
@@ -203,7 +223,72 @@ async def test_zero_logprobs(server, client: openai.AsyncOpenAI,
choice = completion.choices[0]
assert choice.logprobs is not None
assert choice.logprobs.token_logprobs is not None
- assert choice.logprobs.top_logprobs is None
+ assert choice.logprobs.top_logprobs is not None
+ assert len(choice.logprobs.top_logprobs[0]) <= 1


@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME, "zephyr-lora"],
)
async def test_some_logprobs(server, client: openai.AsyncOpenAI,
model_name: str):
# test using token IDs
completion = await client.completions.create(
model=MODEL_NAME,
prompt=[0, 0, 0, 0, 0],
max_tokens=5,
temperature=0.0,
logprobs=5,
)
choice = completion.choices[0]
assert choice.logprobs is not None
assert choice.logprobs.token_logprobs is not None
assert choice.logprobs.top_logprobs is not None
assert len(choice.logprobs.top_logprobs[0]) <= 6


@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME, "zephyr-lora"],
)
async def test_too_many_completion_logprobs(server, client: openai.AsyncOpenAI,
model_name: str):

with pytest.raises(
(openai.BadRequestError, openai.APIError)): # test using token IDs
await client.completions.create(
model=MODEL_NAME,
prompt=[0, 0, 0, 0, 0],
max_tokens=5,
temperature=0.0,
logprobs=6,
)
...
with pytest.raises(
(openai.BadRequestError, openai.APIError)): # test using token IDs
stream = await client.completions.create(
model=MODEL_NAME,
prompt=[0, 0, 0, 0, 0],
max_tokens=5,
temperature=0.0,
logprobs=6,
stream=True,
)
async for chunk in stream:
...

# the server should still work afterwards
completion = await client.completions.create(
model=model_name,
prompt=[0, 0, 0, 0, 0],
max_tokens=5,
temperature=0.0,
)
completion = completion.choices[0].text
assert completion is not None and len(completion) >= 0


@pytest.mark.asyncio
@@ -233,8 +318,10 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI,
chat_completion.choices) == 1
assert chat_completion.choices[0].message is not None
assert chat_completion.choices[0].logprobs is not None
- assert chat_completion.choices[0].logprobs.top_logprobs is not None
- assert len(chat_completion.choices[0].logprobs.top_logprobs[0]) == 5
+ assert chat_completion.choices[0].logprobs.content[
+     0].top_logprobs is not None
+ assert len(
+     chat_completion.choices[0].logprobs.content[0].top_logprobs) == 5
message = chat_completion.choices[0].message
assert message.content is not None and len(message.content) >= 10
assert message.role == "assistant"
@@ -251,10 +338,93 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI,
assert message.content is not None and len(message.content) >= 0


@pytest.mark.asyncio
@pytest.mark.parametrize(
# first test base model, then test loras
"model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
)
async def test_no_logprobs_chat(server, client: openai.AsyncOpenAI,
model_name: str):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
}, {
"role": "user",
"content": "what is 1+1?"
}]

chat_completion = await client.chat.completions.create(model=model_name,
messages=messages,
max_tokens=5,
temperature=0.0,
logprobs=False)

choice = chat_completion.choices[0]
assert choice.logprobs is None


@pytest.mark.asyncio
@pytest.mark.parametrize(
# just test 1 lora hereafter
"model_name",
[MODEL_NAME, "zephyr-lora"],
)
async def test_zero_logprobs_chat(server, client: openai.AsyncOpenAI,
model_name: str):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
}, {
"role": "user",
"content": "what is 1+1?"
}]

chat_completion = await client.chat.completions.create(model=model_name,
messages=messages,
max_tokens=5,
temperature=0.0,
logprobs=True,
top_logprobs=0)

choice = chat_completion.choices[0]
assert choice.logprobs is not None
assert choice.logprobs.content is not None
assert len(choice.logprobs.content[0].top_logprobs) <= 1


@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME, "zephyr-lora"],
)
async def test_some_logprobs_chat(server, client: openai.AsyncOpenAI,
model_name: str):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
}, {
"role": "user",
"content": "what is 1+1?"
}]

chat_completion = await client.chat.completions.create(model=model_name,
messages=messages,
max_tokens=5,
temperature=0.0,
logprobs=True,
top_logprobs=5)

choice = chat_completion.choices[0]
assert choice.logprobs is not None
assert choice.logprobs.content is not None
assert len(choice.logprobs.content[0].top_logprobs) <= 6


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
- async def test_too_many_logprobs(server, client: openai.AsyncOpenAI,
-                                  model_name: str):
+ async def test_too_many_chat_logprobs(server, client: openai.AsyncOpenAI,
+                                       model_name: str):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
@@ -263,13 +433,13 @@ async def test_too_many_logprobs(server, client: openai.AsyncOpenAI,
"content": "what is 1+1?"
}]

- # Default max_logprobs is 5, so this should raise an error
+ # Default max_logprobs is 20, so this should raise an error
with pytest.raises((openai.BadRequestError, openai.APIError)):
stream = await client.chat.completions.create(model=model_name,
messages=messages,
max_tokens=10,
logprobs=True,
- top_logprobs=10,
+ top_logprobs=21,
stream=True)
async for chunk in stream:
...
@@ -279,25 +449,9 @@ async def test_too_many_logprobs(server, client: openai.AsyncOpenAI,
messages=messages,
max_tokens=10,
logprobs=True,
- top_logprobs=10,
+ top_logprobs=30,
stream=False)

- with pytest.raises((openai.BadRequestError, openai.APIError)):
-     stream = await client.completions.create(model=model_name,
-                                               prompt="Test",
-                                               max_tokens=10,
-                                               logprobs=10,
-                                               stream=True)
-     async for chunk in stream:
-         ...
-
- with pytest.raises(openai.BadRequestError):
-     await client.completions.create(model=model_name,
-                                     prompt="Test",
-                                     max_tokens=10,
-                                     logprobs=10,
-                                     stream=False)

# the server should still work afterwards
chat_completion = await client.chat.completions.create(model=model_name,
messages=messages,
@@ -744,13 +898,12 @@ async def test_guided_choice_chat_logprobs(server, client: openai.AsyncOpenAI,
top_logprobs=5,
extra_body=dict(guided_choice=TEST_CHOICE,
guided_decoding_backend=guided_decoding_backend))
- top_logprobs = chat_completion.choices[0].logprobs.top_logprobs
+ top_logprobs = chat_completion.choices[0].logprobs.content[0].top_logprobs

# -9999.0 is the minimum logprob returned by OpenAI
assert all(
-     isinstance(logprob, float) and logprob >= -9999.0
-     for token_dict in top_logprobs
-     for token, logprob in token_dict.items())
+     isinstance(token.logprob, float) and token.logprob >= -9999.0
+     for token in top_logprobs)


@pytest.mark.asyncio
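The test_too_many_*_logprobs tests above pin down the request-side limits enforced by the new check_logprobs validators in vllm/entrypoints/openai/protocol.py (shown below): chat requests accept top_logprobs in [0, 20] and require logprobs to be true, while completion requests accept logprobs in [0, 5]. A small sketch of what a client sees when a limit is exceeded; imports and client setup as in the sketch above, and since the tests accept either BadRequestError or APIError, both are caught here.

async def probe_logprob_limits(client: openai.AsyncOpenAI, model: str) -> None:
    try:
        await client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": "what is 1+1?"}],
            max_tokens=5,
            logprobs=True,
            top_logprobs=21,  # above the [0, 20] chat limit
        )
    except (openai.BadRequestError, openai.APIError) as err:
        print("chat request rejected:", err)

    try:
        await client.completions.create(
            model=model,
            prompt="Test",
            max_tokens=5,
            logprobs=6,  # above the [0, 5] completion limit
        )
    except (openai.BadRequestError, openai.APIError) as err:
        print("completion request rejected:", err)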
50 changes: 43 additions & 7 deletions vllm/entrypoints/openai/protocol.py
@@ -250,6 +250,19 @@ def check_guided_decoding_count(cls, data):
"('guided_json', 'guided_regex' or 'guided_choice').")
return data

@model_validator(mode="before")
@classmethod
def check_logprobs(cls, data):
if "top_logprobs" in data and data["top_logprobs"] is not None:
if "logprobs" not in data or data["logprobs"] is False:
raise ValueError(
"when using `top_logprobs`, `logprobs` must be set to true."
)
elif not 0 <= data["top_logprobs"] <= 20:
raise ValueError(
"`top_logprobs` must be a value in the interval [0, 20].")
return data


class CompletionRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation
@@ -396,6 +409,15 @@ def check_guided_decoding_count(cls, data):
"('guided_json', 'guided_regex' or 'guided_choice').")
return data

@model_validator(mode="before")
@classmethod
def check_logprobs(cls, data):
if "logprobs" in data and data[
"logprobs"] is not None and not 0 <= data["logprobs"] <= 5:
raise ValueError(("if passed, `logprobs` must be a value",
" in the interval [0, 5]."))
return data


class EmbeddingRequest(BaseModel):
# Ordered by official OpenAI API documentation
@@ -415,7 +437,7 @@ def to_pooling_params(self):
return PoolingParams(additional_data=self.additional_data)


- class LogProbs(OpenAIBaseModel):
+ class CompletionLogProbs(OpenAIBaseModel):
text_offset: List[int] = Field(default_factory=list)
token_logprobs: List[Optional[float]] = Field(default_factory=list)
tokens: List[str] = Field(default_factory=list)
@@ -425,7 +447,7 @@ class LogProbs(OpenAIBaseModel):
class CompletionResponseChoice(OpenAIBaseModel):
index: int
text: str
- logprobs: Optional[LogProbs] = None
+ logprobs: Optional[CompletionLogProbs] = None
finish_reason: Optional[str] = None
stop_reason: Optional[Union[int, str]] = Field(
default=None,
@@ -448,7 +470,7 @@ class CompletionResponse(OpenAIBaseModel):
class CompletionResponseStreamChoice(OpenAIBaseModel):
index: int
text: str
- logprobs: Optional[LogProbs] = None
+ logprobs: Optional[CompletionLogProbs] = None
finish_reason: Optional[str] = None
stop_reason: Optional[Union[int, str]] = Field(
default=None,
@@ -488,11 +510,25 @@ class ChatMessage(OpenAIBaseModel):
content: str


class ChatCompletionLogProb(OpenAIBaseModel):
token: str
logprob: float = -9999.0
bytes: Optional[List[int]] = None


class ChatCompletionLogProbsContent(ChatCompletionLogProb):
top_logprobs: List[ChatCompletionLogProb] = Field(default_factory=list)


class ChatCompletionLogProbs(OpenAIBaseModel):
content: Optional[List[ChatCompletionLogProbsContent]] = None


class ChatCompletionResponseChoice(OpenAIBaseModel):
index: int
message: ChatMessage
- logprobs: Optional[LogProbs] = None
- finish_reason: Optional[str] = None
+ logprobs: Optional[ChatCompletionLogProbs] = None
+ finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None
stop_reason: Optional[Union[int, str]] = None


@@ -513,8 +549,8 @@ class DeltaMessage(OpenAIBaseModel):
class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
index: int
delta: DeltaMessage
- logprobs: Optional[LogProbs] = None
- finish_reason: Optional[str] = None
+ logprobs: Optional[ChatCompletionLogProbs] = None
+ finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None
stop_reason: Optional[Union[int, str]] = None


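For reference, the new ChatCompletionLogProb* models above serialize, per choice, into roughly the OpenAI-style structure sketched below (assuming default pydantic serialization; the token strings, logprob values, and byte lists are illustrative only).

example_chat_logprobs = {
    "content": [
        {
            "token": "2",
            "logprob": -0.31,
            "bytes": [50],  # UTF-8 bytes of the token text
            "top_logprobs": [
                {"token": "2", "logprob": -0.31, "bytes": [50]},
                {"token": " 2", "logprob": -1.87, "bytes": [32, 50]},
            ],
        },
        # ... one entry per generated token
    ],
}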