OpenAI Server refactoring (vllm-project#2360)
FlorianJoncour committed Jan 17, 2024
1 parent cc11aa4 commit c88318b
Showing 8 changed files with 954 additions and 643 deletions.
3 changes: 3 additions & 0 deletions .buildkite/test-pipeline.yaml
@@ -19,6 +19,9 @@ steps:
- label: Engine Test
command: pytest -v -s engine

- label: Entrypoints Test
command: pytest -v -s entrypoints

- label: Kernels Test
command: pytest -v -s kernels
soft_fail: true
3 changes: 3 additions & 0 deletions requirements-dev.txt
@@ -16,3 +16,6 @@ pytest-asyncio
httpx
einops # required for MPT
flash_attn # required for HuggingFace's llama implementation
openai
requests
ray
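The three new development dependencies back the end-to-end server test added below: openai is the official client used for correctness checks, requests polls the server's /health route while it starts, and ray hosts the server process on a GPU worker. As a minimal sketch of the kind of startup poll requests is pulled in for (the function name, URL, and timeout below are illustrative assumptions mirroring the test file, not part of the commit):

# Illustrative sketch of a startup health poll with `requests`; mirrors the
# _wait_for_server loop in the new test file below.
import time

import requests


def wait_for_health(url: str = "http://localhost:8000/health",
                    timeout_s: float = 600.0) -> None:
    start = time.time()
    while time.time() - start < timeout_s:
        try:
            if requests.get(url).status_code == 200:
                return
        except requests.ConnectionError:
            pass  # server not accepting connections yet
        time.sleep(0.5)
    raise RuntimeError("Server failed to start in time.")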
@@ -1,12 +1,12 @@
from argparse import Namespace
from dataclasses import dataclass
import os
import pathlib

import pytest
from fastapi.testclient import TestClient

from vllm.entrypoints.openai.api_server import *
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.protocol import ChatCompletionRequest

chatml_jinja_path = pathlib.Path(os.path.dirname(os.path.abspath(
__file__))).parent.parent / "examples/template_chatml.jinja"
@@ -48,21 +48,24 @@
'content': 'What is the capital of'
},
]
client = TestClient(app)


@dataclass
class MockTokenizer:
chat_template = None


@dataclass
class MockServingChat:
tokenizer: MockTokenizer


def test_load_chat_template():
# Testing chatml template
mock_args = Namespace(chat_template=chatml_jinja_path)
tokenizer = MockTokenizer()

# Call the function with the mocked args
load_chat_template(mock_args, tokenizer)
mock_serving_chat = MockServingChat(tokenizer)
OpenAIServingChat._load_chat_template(mock_serving_chat,
chat_template=chatml_jinja_path)

template_content = tokenizer.chat_template

@@ -76,11 +79,11 @@ def test_load_chat_template():
def test_no_load_chat_template():
# Testing chatml template
template = "../../examples/does_not_exist"
mock_args = Namespace(chat_template=template)
tokenizer = MockTokenizer()

# Call the function with the mocked args
load_chat_template(mock_args, tokenizer=tokenizer)
mock_serving_chat = MockServingChat(tokenizer)
OpenAIServingChat._load_chat_template(mock_serving_chat,
chat_template=template)
template_content = tokenizer.chat_template

# Test assertions
@@ -97,9 +100,9 @@ async def test_get_gen_prompt(model, template, add_generation_prompt,
expected_output):
# Initialize the tokenizer
tokenizer = get_tokenizer(tokenizer_name=model)

mock_args = Namespace(chat_template=template)
load_chat_template(mock_args, tokenizer)
mock_serving_chat = MockServingChat(tokenizer)
OpenAIServingChat._load_chat_template(mock_serving_chat,
chat_template=template)

# Create a mock request object using keyword arguments
mock_request = ChatCompletionRequest(
@@ -115,8 +118,3 @@ async def test_get_gen_prompt(model, template, add_generation_prompt,

# Test assertion
assert result == expected_output, f"The generated prompt does not match the expected output for model {model} and template {template}"


def test_health_endpoint():
response = client.get("/health")
assert response.status_code == 200
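The updated test no longer calls a module-level load_chat_template helper; it invokes OpenAIServingChat._load_chat_template unbound, passing a lightweight dataclass that only carries the tokenizer attribute the method touches. A generic, self-contained sketch of that duck-typing pattern (all names below are hypothetical, not vLLM's):

# Hypothetical illustration of calling an instance method unbound with a
# minimal stand-in object; only the attributes the method reads are mocked.
from dataclasses import dataclass, field
from typing import Optional


class TemplateHolder:
    def _load_template(self, template_text: str) -> None:
        # A real implementation might also read files or log warnings.
        self.tokenizer.chat_template = template_text


@dataclass
class FakeTokenizer:
    chat_template: Optional[str] = None


@dataclass
class FakeHolder:
    tokenizer: FakeTokenizer = field(default_factory=FakeTokenizer)


fake = FakeHolder()
TemplateHolder._load_template(fake, "{{ messages }}")  # unbound call on the stand-in
assert fake.tokenizer.chat_template == "{{ messages }}"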
193 changes: 193 additions & 0 deletions tests/entrypoints/test_openai_server.py
@@ -0,0 +1,193 @@
import time
import subprocess

import sys
import pytest
import requests
import ray # using Ray for overall ease of process management, parallel requests, and debugging.
import openai # use the official client for correctness check

MAX_SERVER_START_WAIT_S = 600  # wait for the server to start, up to 600 seconds
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" # any model with a chat template should work here

pytestmark = pytest.mark.asyncio


@ray.remote(num_gpus=1)
class ServerRunner:

def __init__(self, args):
self.proc = subprocess.Popen(
["python3", "-m", "vllm.entrypoints.openai.api_server"] + args,
stdout=sys.stdout,
stderr=sys.stderr,
)
self._wait_for_server()

def ready(self):
return True

def _wait_for_server(self):
# run health check
start = time.time()
while True:
try:
if requests.get(
"http://localhost:8000/health").status_code == 200:
break
except Exception as err:
if self.proc.poll() is not None:
raise RuntimeError("Server exited unexpectedly.") from err

time.sleep(0.5)
if time.time() - start > MAX_SERVER_START_WAIT_S:
raise RuntimeError(
"Server failed to start in time.") from err

def __del__(self):
if hasattr(self, "proc"):
self.proc.terminate()


@pytest.fixture(scope="session")
def server():
ray.init()
server_runner = ServerRunner.remote([
"--model",
MODEL_NAME,
"--dtype",
"bfloat16", # use half precision for speed and memory savings in CI environment
"--max-model-len",
"8192"
])
ray.get(server_runner.ready.remote())
yield server_runner
ray.shutdown()


@pytest.fixture(scope="session")
def client():
client = openai.AsyncOpenAI(
base_url="http://localhost:8000/v1",
api_key="token-abc123",
)
yield client


async def test_single_completion(server, client: openai.AsyncOpenAI):
completion = await client.completions.create(model=MODEL_NAME,
prompt="Hello, my name is",
max_tokens=5,
temperature=0.0)

assert completion.id is not None
assert completion.choices is not None and len(completion.choices) == 1
assert completion.choices[0].text is not None and len(
completion.choices[0].text) >= 5
assert completion.choices[0].finish_reason == "length"
assert completion.usage == openai.types.CompletionUsage(
completion_tokens=5, prompt_tokens=6, total_tokens=11)


async def test_single_chat_session(server, client: openai.AsyncOpenAI):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
}, {
"role": "user",
"content": "what is 1+1?"
}]

# test single completion
chat_completion = await client.chat.completions.create(
model=MODEL_NAME,
messages=messages,
max_tokens=10,
)
assert chat_completion.id is not None
assert chat_completion.choices is not None and len(
chat_completion.choices) == 1
assert chat_completion.choices[0].message is not None
message = chat_completion.choices[0].message
assert message.content is not None and len(message.content) >= 10
assert message.role == "assistant"
messages.append({"role": "assistant", "content": message.content})

# test multi-turn dialogue
messages.append({"role": "user", "content": "express your result in json"})
chat_completion = await client.chat.completions.create(
model=MODEL_NAME,
messages=messages,
max_tokens=10,
)
message = chat_completion.choices[0].message
assert message.content is not None and len(message.content) >= 0


async def test_completion_streaming(server, client: openai.AsyncOpenAI):
prompt = "What is an LLM?"

single_completion = await client.completions.create(
model=MODEL_NAME,
prompt=prompt,
max_tokens=5,
temperature=0.0,
)
single_output = single_completion.choices[0].text
single_usage = single_completion.usage

stream = await client.completions.create(
model=MODEL_NAME,
prompt=prompt,
max_tokens=5,
temperature=0.0,
stream=True,
)
chunks = []
async for chunk in stream:
chunks.append(chunk.choices[0].text)
assert chunk.choices[0].finish_reason == "length"
assert chunk.usage == single_usage
assert "".join(chunks) == single_output


async def test_chat_streaming(server, client: openai.AsyncOpenAI):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
}, {
"role": "user",
"content": "what is 1+1?"
}]

# test single completion
chat_completion = await client.chat.completions.create(
model=MODEL_NAME,
messages=messages,
max_tokens=10,
temperature=0.0,
)
output = chat_completion.choices[0].message.content
stop_reason = chat_completion.choices[0].finish_reason

# test streaming
stream = await client.chat.completions.create(
model=MODEL_NAME,
messages=messages,
max_tokens=10,
temperature=0.0,
stream=True,
)
chunks = []
async for chunk in stream:
delta = chunk.choices[0].delta
if delta.role:
assert delta.role == "assistant"
if delta.content:
chunks.append(delta.content)
assert chunk.choices[0].finish_reason == stop_reason
assert "".join(chunks) == output


if __name__ == "__main__":
pytest.main([__file__])
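For reference, the endpoints exercised above can also be queried outside the pytest/Ray harness. A minimal synchronous sketch (assuming a server already launched the same way the fixture does, on the default port and with the same placeholder API key; this is an illustration, not part of the commit):

# Standalone sketch: query a locally running vLLM OpenAI-compatible server
# with the official client. Assumes the server is already up on localhost:8000.
import openai

client = openai.OpenAI(
    base_url="http://localhost:8000/v1",
    api_key="token-abc123",  # placeholder key, as in the test fixture
)

completion = client.completions.create(
    model="HuggingFaceH4/zephyr-7b-beta",
    prompt="Hello, my name is",
    max_tokens=5,
    temperature=0.0,
)
print(completion.choices[0].text)

chat = client.chat.completions.create(
    model="HuggingFaceH4/zephyr-7b-beta",
    messages=[{"role": "user", "content": "what is 1+1?"}],
    max_tokens=10,
)
print(chat.choices[0].message.content)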