add vlm api, api sever and api config #1004

Closed · wants to merge 3 commits
29 changes: 29 additions & 0 deletions configs/api_examples/eval_api_vllm.py
@@ -0,0 +1,29 @@
from mmengine.config import read_base
from opencompass.partitioners import NaivePartitioner
from opencompass.runners.local_api import LocalAPIRunner
from opencompass.tasks import OpenICLInferTask
from opencompass.models import VLLM_API

with read_base():
    from ..summarizers.medium import summarizer
    from ..datasets.mmlu.mmlu_ppl import mmlu_datasets


datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), [])

models = [
    dict(
        type=VLLM_API,
        abbr='Qwen-7b-api',
        path='',
        url='http://0.0.0.0:60/generate',
        max_seq_len=2048,
        batch_size=1000,
        generation_kwargs={
            'temperature': 0.8,
            'max_out_len': 1,
            'prompt_logprobs': 0,
        },
    ),
]
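
The config above imports NaivePartitioner, LocalAPIRunner and OpenICLInferTask but never uses them, so an infer block seems to be missing. A minimal sketch following the pattern of the other configs/api_examples/eval_api_*.py files (the worker counts here are assumptions, not part of this PR):

infer = dict(
    partitioner=dict(type=NaivePartitioner),
    runner=dict(
        type=LocalAPIRunner,
        max_num_workers=2,   # assumed; tune to the server's capacity
        concurrent_users=2,  # assumed
        task=dict(type=OpenICLInferTask)),
)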

1 change: 1 addition & 0 deletions opencompass/models/__init__.py
Original file line number Diff line number Diff line change
@@ -32,3 +32,4 @@
from .xunfei_api import XunFei # noqa: F401
from .zhipuai_api import ZhiPuAI # noqa: F401
from .zhipuai_v2_api import ZhiPuV2AI # noqa: F401
from .vllm_api import VLLM_API # noqa: F401
192 changes: 192 additions & 0 deletions opencompass/models/vllm_api.py
@@ -0,0 +1,192 @@
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Optional, Union

import numpy as np
import requests
import torch

from opencompass.utils.prompt import PromptList

from .base_api import BaseAPIModel

PromptType = Union[PromptList, str]


class VLLM_API(BaseAPIModel):
    """Model wrapper that queries a vLLM generation server over HTTP (see
    vllm_api_sever.py in this PR)."""

    def __init__(self,
                 path: str,
                 url: str,
                 query_per_second: int = 2,
                 max_seq_len: int = 2048,
                 meta_template: Optional[Dict] = None,
                 retry: int = 2,
                 generation_kwargs: Dict = {
                     'temperature': 0.8,
                     'prompt_logprobs': 0,
                 }):
        super().__init__(path=path,
                         max_seq_len=max_seq_len,
                         query_per_second=query_per_second,
                         meta_template=meta_template,
                         retry=retry,
                         generation_kwargs=generation_kwargs)

        self.url = url
        self.generation_kwargs = generation_kwargs

    def post_http_request(self,
                          prompt: str,
                          api_url: str,
                          n: int = 1,
                          stream: bool = False,
                          max_out_len: int = 0) -> requests.Response:
        """Send a single generation request to the vLLM server."""
        headers = {"User-Agent": "Test Client"}
        # Fall back to the method arguments when a key is missing from
        # generation_kwargs, so a minimal config does not raise KeyError.
        pload = {
            "prompt": prompt,
            "n": self.generation_kwargs.get('n', n),
            "use_beam_search": self.generation_kwargs.get('use_beam_search', False),
            "temperature": self.generation_kwargs.get('temperature', 0.8),
            "stream": stream,
            "prompt_logprobs": self.generation_kwargs.get('prompt_logprobs', 0),
            "max_tokens": self.generation_kwargs.get('max_out_len', max_out_len),
        }
        response = requests.post(api_url, headers=headers, json=pload,
                                 stream=True)
        return response

    def generate(
        self,
        inputs: List[PromptType],
        max_out_len: int = 512,
    ) -> List[str]:
        """Generate results given a list of inputs."""
        with ThreadPoolExecutor() as executor:
            results = list(
                executor.map(self._generate, inputs,
                             [max_out_len] * len(inputs)))
        self.flush()
        return results

    def _generate(
        self,
        input: PromptType,
        max_out_len: int = 512,
    ) -> str:
        """Generate results given an input.

        Args:
            input (str or PromptList): A string or PromptDict.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.

        Returns:
            str: The generated string.
        """
        max_num_retries = 0
        while max_num_retries < self.retry:
            self.acquire()
            try:
                raw_response = self.post_http_request(prompt=input,
                                                      api_url=self.url)
                response = raw_response.json()
            except Exception as err:
                print('Request Error:{}'.format(err))
                self.release()
                time.sleep(3)
                continue

            self.release()

            if response is None:
                print('Connection error, reconnect.')
                # If a connection error occurs, frequent requests will cause
                # a continuously unstable network, therefore wait here
                # to slow down the requests.
                self.wait()
                continue
            if raw_response.status_code == 200:
                try:
                    msg = response['output']['outputs'][0]['text']
                    return msg
                except KeyError:
                    print(response)
                    self.logger.error(str(response.get('error_code')))
                    if response.get('error_code') == 336007:
                        # exceeded maximum input length
                        return ''

                    time.sleep(1)
                    continue

            print(response)
            max_num_retries += 1

        raise RuntimeError(response['error_msg'])



    def get_ppl(self,
                inputs: List[str],
                mask_length: Optional[List[int]] = None) -> List[float]:
        """Get perplexity-style losses for a list of inputs."""
        with ThreadPoolExecutor() as executor:
            results = list(executor.map(self._generate_ppl, inputs))
        self.flush()
        results = np.array(results)
        return results


    def _generate_ppl(
        self,
        input: PromptType,
    ) -> float:
        """Compute a PPL-style loss for one input from the server's
        prompt_logprobs.

        Args:
            input (str or PromptList): A string or PromptDict.

        Returns:
            float: The negative mean log-probability of the prompt tokens.
        """
        max_num_retries = 0
        while max_num_retries < self.retry:
            self.acquire()
            try:
                raw_response = self.post_http_request(prompt=input,
                                                      api_url=self.url)
                response = raw_response.json()
            except Exception as err:
                print('Request Error:{}'.format(err))
                self.release()
                time.sleep(3)
                continue

            self.release()

            if response is None:
                print('Connection error, reconnect.')
                # If a connection error occurs, frequent requests will cause
                # a continuously unstable network, therefore wait here
                # to slow down the requests.
                self.wait()
                continue
            if raw_response.status_code == 200:
                try:
                    # The first prompt token has no logprob, so drop it.
                    outputs_prob = response['output']['prompt_logprobs'][1:]
                    prompt_token_ids = response['output']['prompt_token_ids'][1:]
                    outputs_prob_list = [
                        outputs_prob[i][str(prompt_token_ids[i])]['logprob']
                        for i in range(len(outputs_prob))
                    ]
                    outputs_prob_list = torch.tensor(outputs_prob_list)
                    # Negative mean log-probability over the prompt tokens.
                    loss = -1 * outputs_prob_list.sum(-1).cpu().detach().numpy() / len(prompt_token_ids)
                    return loss
                except KeyError:
                    print(response)
                    self.logger.error(str(response.get('error_code')))
                    if response.get('error_code') == 336007:
                        # exceeded maximum input length
                        return ''
                    time.sleep(1)
                    continue

            print(response)
            max_num_retries += 1

        raise RuntimeError(response['error_msg'])
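
To make the get_ppl path concrete: the server returns one logprob dict per prompt token, and _generate_ppl averages the negative logprobs of the tokens that were actually in the prompt. A small sketch with made-up numbers (the token ids and logprob values below are illustrative, not real server output; in vLLM the first prompt token has no logprob, which is why the wrapper drops the leading entry):

import torch

# Hypothetical reply fields after dropping the first (unscored) prompt token.
prompt_token_ids = [101, 42, 7]
prompt_logprobs = [
    {'101': {'logprob': -1.2}},
    {'42': {'logprob': -0.7}},
    {'7': {'logprob': -2.1}},
]

logprobs = torch.tensor([
    prompt_logprobs[i][str(prompt_token_ids[i])]['logprob']
    for i in range(len(prompt_logprobs))
])
loss = -logprobs.sum().item() / len(prompt_token_ids)
print(loss)  # (1.2 + 0.7 + 2.1) / 3 = 1.33...; a lower loss means a more likely prompt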

129 changes: 129 additions & 0 deletions opencompass/models/vllm_api_sever.py
@@ -0,0 +1,129 @@
"""
NOTE: This API server is used only for demonstrating usage of AsyncEngine
and simple performance benchmarks. It is not intended for production use.
For production use, we recommend using our OpenAI compatible server.
We are also not going to accept PRs modifying this file, please
change `vllm/entrypoints/openai/api_server.py` instead.
"""

import argparse
import json
import ssl
from typing import AsyncGenerator

from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse
import uvicorn

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid

TIMEOUT_KEEP_ALIVE = 5 # seconds.
app = FastAPI()
engine = None


@app.get("/health")
async def health() -> Response:
"""Health check."""
return Response(status_code=200)


@app.post("/generate")
async def generate(request: Request) -> Response:
"""Generate completion for the request.

The request should be a JSON object with the following fields:
- prompt: the prompt to use for the generation.
- stream: whether to stream the results or not.
- other fields: the sampling parameters (See `SamplingParams` for details).
"""
request_dict = await request.json()
prompt = request_dict.pop("prompt")
stream = request_dict.pop("stream", False)
sampling_params = SamplingParams(**request_dict)

if isinstance(prompt, list):
results_generator_all = []
for t in prompt:
request_id = random_uuid()
results_generator_all.append(engine.generate(t, sampling_params, request_id))
else:
request_id = random_uuid()
results_generator = engine.generate(prompt, sampling_params, request_id)

# Streaming case
async def stream_results() -> AsyncGenerator[bytes, None]:
async for request_output in results_generator:
prompt = request_output.prompt
text_outputs = [
prompt + output.text for output in request_output.outputs
]
ret = {"text": text_outputs}
yield (json.dumps(ret) + "\0").encode("utf-8")

if stream:
return StreamingResponse(stream_results())

# Non-streaming case
if isinstance(prompt, list):
final_output_all = []
for results_generator_current in results_generator_all:
final_output = None
async for request_output in results_generator_current:
final_output = request_output

assert final_output is not None
ret = {"output": final_output}
final_output_all.append(ret)
return final_output_all

else:
final_output = None
async for request_output in results_generator:
final_output = request_output

assert final_output is not None
ret = {"output": final_output}
return ret


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default=None)
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--ssl-keyfile", type=str, default=None)
parser.add_argument("--ssl-certfile", type=str, default=None)
parser.add_argument("--ssl-ca-certs",
type=str,
default=None,
help="The CA certificates file")
parser.add_argument(
"--ssl-cert-reqs",
type=int,
default=int(ssl.CERT_NONE),
help="Whether client certificate is required (see stdlib ssl module's)"
)
parser.add_argument(
"--root-path",
type=str,
default=None,
help="FastAPI root_path when app is behind a path based routing proxy")
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()

engine_args = AsyncEngineArgs.from_cli_args(args)
engine = AsyncLLMEngine.from_engine_args(engine_args)

app.root_path = args.root_path
uvicorn.run(app,
host=args.host,
port=args.port,
log_level="debug",
timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
ssl_keyfile=args.ssl_keyfile,
ssl_certfile=args.ssl_certfile,
ssl_ca_certs=args.ssl_ca_certs,
ssl_cert_reqs=args.ssl_cert_reqs)
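
A quick way to sanity-check the server before a full OpenCompass run is to post one request with the same payload shape that VLLM_API.post_http_request builds. This is only a sketch: the launch command, model name and port are assumptions chosen to match the example config, and the printed fields are simply the ones the wrapper itself reads from the JSON reply.

# Launch the server first, e.g.:
#   python opencompass/models/vllm_api_sever.py --model Qwen/Qwen-7B --port 60
import requests

pload = {
    'prompt': 'The capital of France is',
    'n': 1,
    'use_beam_search': False,
    'temperature': 0.8,
    'stream': False,
    'prompt_logprobs': 0,
    'max_tokens': 16,
}
resp = requests.post('http://0.0.0.0:60/generate', json=pload)
output = resp.json()['output']
print(output['outputs'][0]['text'])    # generated continuation
print(len(output['prompt_logprobs']))  # per-token logprobs used by get_ppl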