[Inference] Fix bug for feat/online-server, updating docstrings #5598

Merged · 12 commits · May 8, 2024
5 changes: 2 additions & 3 deletions colossalai/inference/config.py
@@ -1,9 +1,8 @@
"""
Our config contains various options for inference optimization; it is a unified API that wraps all the configurations for inference.
"""
import dataclasses
import logging
from dataclasses import dataclass
from dataclasses import dataclass, fields
from typing import Any, Dict, Optional, Union

import torch
@@ -215,7 +214,7 @@ def to_generation_config(self, model_config) -> GenerationConfig:
@classmethod
def from_dict(cls, config_dict: Dict[str, Any]) -> "InferenceConfig":
# Get the list of attributes of this dataclass.
attrs = [attr.name for attr in dataclasses.fields(cls)]
attrs = [attr.name for attr in fields(cls)]
inference_config_args = {}
for attr in attrs:
if attr in config_dict:
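The hunk above swaps `dataclasses.fields(cls)` for the directly imported `fields(cls)`; the filtering pattern itself is unchanged. A minimal, self-contained sketch of that pattern, using a hypothetical `ToyConfig` rather than the real `InferenceConfig`:

```python
from dataclasses import dataclass, fields
from typing import Any, Dict


@dataclass
class ToyConfig:
    # Hypothetical stand-in for InferenceConfig, just to illustrate the pattern.
    max_batch_size: int = 8
    dtype: str = "fp16"

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]) -> "ToyConfig":
        # Keep only keys that match declared dataclass fields, drop everything else.
        attrs = [attr.name for attr in fields(cls)]
        kwargs = {k: v for k, v in config_dict.items() if k in attrs}
        return cls(**kwargs)


print(ToyConfig.from_dict({"max_batch_size": 4, "unknown_key": 1}))
# -> ToyConfig(max_batch_size=4, dtype='fp16')
```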
52 changes: 36 additions & 16 deletions colossalai/inference/core/async_engine.py
@@ -1,7 +1,7 @@
import asyncio
import logging
from functools import partial
from typing import AsyncIterator, Dict, Iterable, List, Optional, Tuple, Type
from typing import AsyncIterator, Dict, Iterable, List, Optional, Set, Tuple, Type

from colossalai.inference.core.engine import InferenceEngine

@@ -10,7 +10,7 @@
logger = logging.getLogger("colossalai-inference")


def _raise_exception_on_finish(task: asyncio.Task, request_tracker: "RequestTracker") -> None:
def _raise_exception_on_finish(task: asyncio.Task, request_tracker: "Tracer") -> None:
msg = "Task finished unexpectedly. This should never happen! "
try:
try:
@@ -26,8 +26,14 @@ def _raise_exception_on_finish(task: asyncio.Task, request_tracker: "RequestTrac


class RequstStream:
"""A stream of Output for a request that can be
iterated over asynchronously."""
"""
A stream of Output for a request that can be iterated over asynchronously.
Attributes: 1.request_id: The id of the request.
2._future: A future that will be set when the request is finished.
Methods: set_result and get_result, results will be set when finished, for once, and
the `self.future` will be set to done.

"""

def __init__(self, request_id: int) -> None:
self.request_id = request_id
@@ -51,6 +57,10 @@ def finished(self) -> bool:
class Tracer:
"""
Recording new requests and finished requests.
Attributes: 1. _request_streams: One stream is created for each request to trace its output.
2. _finished_requests: A queue that stores the finished requests.
3. _new_requests: New requests are stored in this queue first, before being sent to the engine.
4. new_requests_event: An event that notifies the engine that there are new requests.
"""

def __init__(self) -> None:
@@ -93,8 +103,8 @@ def add_request(self, request_id: int, **engine_add_request_kwargs) -> RequstStr
raise KeyError(f"Request {request_id} already exists.")

stream = RequstStream(request_id)
logger.info(f"Added request {request_id}.")
self._new_requests.put_nowait((stream, {"request_id": request_id, **engine_add_request_kwargs}))

self.new_requests_event.set()

return stream
@@ -108,6 +118,7 @@ def abort_request(self, request_id: int, *, verbose: bool = False) -> None:

if request_id not in self._request_streams or self._request_streams[request_id].finished:
# The request has already finished or been aborted.
# Requests still in new_requests will be aborted when they are fetched (if marked as aborted)
return

self._request_streams[request_id].set_result(None)
@@ -117,9 +128,18 @@ def get_new_requests(self):
Get new requests from http server.
"""
new_requests: List[Dict] = []
finished_requests: Set[int] = set()

while not self._finished_requests.empty():
request_id = self._finished_requests.get_nowait()
finished_requests.add(request_id)

while not self._new_requests.empty():
stream, new_request = self._new_requests.get_nowait()
if new_request["request_id"] in finished_requests:
# The request has been aborted.
stream.set_result(None)
continue
self._request_streams[stream.request_id] = stream
new_requests.append(new_request)

@@ -133,7 +153,8 @@ async def wait_for_new_requests(self):

class _AsyncInferenceEngine(InferenceEngine):
"""
Async methods for Inference Engine.
Async methods for the Inference Engine. This engine is an extension of InferenceEngine, and the additional
methods are only used for the asynchronous (online serving) workflow.
Methods: 1. async_step: The async version of Engine.step()
"""

async def async_step(self) -> List[str]:
@@ -161,22 +182,23 @@ async def async_step(self) -> List[str]:
if self.inference_config.pad_input:
logits = logits[:, -1, :]
self.request_handler.search_tokens(self.generation_config, logits)
# Return: List[Sequence]

finished_sequences = self.request_handler.update()
for sequence in finished_sequences:
sequence.output = self.tokenizer.decode(sequence.output_token_id)

return finished_sequences, self.request_handler.current_requests_in_batch() > 0
return finished_sequences, self.request_handler.total_requests_in_batch_bucket() > 0


class AsyncInferenceEngine:
"""An asynchronous wrapper for LLMEngine.
"""An asynchronous wrapper for the InferenceEngine class.

This class is used to wrap the InferenceEngine class to make it asynchronous.
It uses asyncio to create a background loop that keeps processing incoming
requests. The LLMEngine is kicked by the generate method when there are
requests in the waiting queue. The generate method yields the outputs
from the InferenceEngine to the caller.
requests. Note that this class does not hold the model directly; when a new request
comes in, `add_request` is called first and the Tracer records the request, then hands
it to the background `InferenceEngine` (run in the background loop) for processing. You
can consider this engine as an interface.
"""

_engine_class: Type[_AsyncInferenceEngine] = _AsyncInferenceEngine
@@ -253,7 +275,7 @@ async def add_request(
prompt_token_ids: Optional[List[int]] = None,
) -> RequstStream:
"""
Add a request to the background tracker(waitting queue), start the background loop if needed.
Add a request to the background tracker(waiting queue), start the background loop if needed.
"""
if not self.background_loop_status:
if self.start_engine_loop:
@@ -276,14 +298,12 @@ async def generate(
"""
Generate output from a request. It receives the request from the http server, adds it into the
waiting queue of the Async Engine and streams the output sequence.

"""
try:
stream = await self.add_request(request_id, prompt, prompt_token_ids=prompt_token_ids)
return await stream.get_result()

except (Exception, asyncio.CancelledError) as e:
# If there is an exception or coroutine is cancelled, abort the
# request.
# If there is an exception or coroutine is cancelled, abort the request.
self._abort(request_id)
raise e
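As a self-contained illustration of the future-backed stream that the `RequstStream` docstring describes (the class below is a toy re-implementation, not the library's API): `get_result()` awaits a future that `set_result()` completes exactly once when the request finishes.

```python
import asyncio


class ToyRequestStream:
    """Toy sketch: the result is set once; the future is then done."""

    def __init__(self, request_id: int) -> None:
        self.request_id = request_id
        self._future: asyncio.Future = asyncio.get_running_loop().create_future()

    def set_result(self, result) -> None:
        if not self._future.done():  # results are only set once
            self._future.set_result(result)

    async def get_result(self):
        return await self._future  # resolves when set_result() runs

    @property
    def finished(self) -> bool:
        return self._future.done()


async def main() -> None:
    stream = ToyRequestStream(request_id=0)
    # A producer (the engine, in the real code) finishes the request later.
    asyncio.get_running_loop().call_later(0.1, stream.set_result, "generated text")
    print(await stream.get_result())  # -> generated text
    print(stream.finished)            # -> True


asyncio.run(main())
```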
16 changes: 7 additions & 9 deletions colossalai/inference/core/engine.py
@@ -450,19 +450,15 @@ def generate(
List[str]: Inference result returned by one generation.
"""
with torch.inference_mode():
if generation_config is not None:
self.generation_config = generation_config

if prompts is not None or prompts_token_ids is not None:
if isinstance(prompts, str) and isinstance(request_ids, int):
prompts = [prompts]
request_ids = [request_ids]
self.add_request(request_ids=request_ids, prompts=prompts, prompts_token_ids=prompts_token_ids)

output_seqs_list = []
total_tokens_list = []

# intuition: If the user provides a generation config, we should replace the existing one.
if generation_config is not None:
self.generation_config = generation_config

if self.use_spec_dec:
assert self.drafter is not None, "Drafter Model is not initialized."
@@ -527,6 +523,9 @@ def add_request(

block_size = self.inference_config.block_size

if request_ids is not None and not isinstance(request_ids, list):
request_ids = [request_ids]

if prompts is not None and not isinstance(prompts, list):
prompts = [prompts]

@@ -536,9 +535,10 @@
"input_ids"
]

# list of torch Tensor
if isinstance(prompts_token_ids, list):
if isinstance(prompts_token_ids[0], torch.Tensor):
prompts_token_ids = [prompt_token_ids.tolist() for prompt_token_ids in prompts_token_ids]
prompts_token_ids = [prompt_token_id.tolist() for prompt_token_id in prompts_token_ids]
elif isinstance(prompts_token_ids, torch.Tensor) or isinstance(prompts_token_ids, np.ndarray):
prompts_token_ids = prompts_token_ids.tolist()
else:
@@ -644,8 +644,6 @@ def step(self) -> List[str]:
logits = logits[:, -1, :]
next_tokens = self.request_handler.search_tokens(self.generation_config, logits)
self.request_handler.append_next_tokens(next_tokens)

self.request_handler.search_tokens(self.generation_config, logits)
finished_sequences = self.request_handler.update()

return finished_sequences
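The `add_request` hunks above normalize scalar `request_ids`/`prompts` into lists and flatten token-id tensors into plain Python lists. A standalone sketch of that normalization follows (the function name is illustrative, not the engine's API):

```python
from typing import List, Optional, Union

import numpy as np
import torch


def normalize_inputs(
    request_ids: Optional[Union[int, List[int]]] = None,
    prompts: Optional[Union[str, List[str]]] = None,
    prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray, None] = None,
):
    # Scalars become single-element lists.
    if request_ids is not None and not isinstance(request_ids, list):
        request_ids = [request_ids]
    if prompts is not None and not isinstance(prompts, list):
        prompts = [prompts]
    # Token ids: a list of tensors, a tensor, or an ndarray all end up as nested lists.
    if isinstance(prompts_token_ids, list):
        if prompts_token_ids and isinstance(prompts_token_ids[0], torch.Tensor):
            prompts_token_ids = [t.tolist() for t in prompts_token_ids]
    elif isinstance(prompts_token_ids, (torch.Tensor, np.ndarray)):
        prompts_token_ids = prompts_token_ids.tolist()
    return request_ids, prompts, prompts_token_ids


print(normalize_inputs(0, "hello", torch.tensor([[1, 2, 3]])))
# -> ([0], ['hello'], [[1, 2, 3]])
```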
2 changes: 1 addition & 1 deletion colossalai/inference/core/request_handler.py
@@ -327,7 +327,7 @@ def update_batch_finished(self, batch: BatchBucket, generation_config: Generatio
def check_unfinished_seqs(self) -> bool:
return self._has_waiting() or not self.running_list.is_empty()

def current_requests_in_batch(self) -> int:
def total_requests_in_batch_bucket(self) -> int:
return self.prefill_bb.current_batch_size + self.running_bb.current_batch_size

def search_tokens(self, generation_config: GenerationConfig, logits):
40 changes: 6 additions & 34 deletions colossalai/inference/server/api_server.py
@@ -6,9 +6,10 @@
Usage: (for local user)
- First, launch an API locally. `python3 -m colossalai.inference.server.api_server --model path of your llama2 model`
- Second, you can turn to the page `http://127.0.0.1:8000/docs` to check the api
- For completion service, you can invoke it by using `curl -X POST http://127.0.0.1:8000/v1/completion \
- For completion service, you can invoke it by using `curl -X POST http://127.0.0.1:8000/completion \
-H 'Content-Type: application/json' \
-d '{"prompt":"hello, who are you? ","stream":"False"}'`
Version: V1.0
"""

import argparse
@@ -36,7 +37,8 @@
app = FastAPI()


@app.get("/v0/models")
# NOTE: (CjhHa1) models are still under development, need to be updated
@app.get("/models")
def get_available_models() -> Response:
return JSONResponse(supported_models_dict)

@@ -81,7 +83,7 @@ def stream_results():
return JSONResponse(ret)


@app.post("/v1/completion")
@app.post("/completion")
async def create_completion(request: Request):
request_dict = await request.json()
stream = request_dict.pop("stream", "false").lower()
@@ -95,7 +97,7 @@ async def create_completion(request: Request):
return JSONResponse(content=ret)


@app.post("/v1/chat")
@app.post("/chat")
async def create_chat(request: Request):
request_dict = await request.json()

@@ -127,14 +129,6 @@ def add_engine_config(parser):
help="model context length. If unspecified, " "will be automatically derived from the model.",
)
# Parallel arguments
parser.add_argument(
"--worker-use-ray",
action="store_true",
help="use Ray for distributed serving, will be " "automatically set when using more than 1 GPU",
)

parser.add_argument("--pipeline-parallel-size", "-pp", type=int, default=1, help="number of pipeline stages")

parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1, help="number of tensor parallel replicas")

# KV cache arguments
@@ -149,28 +143,6 @@
default=None,
help=f"Allowed choices are {','.join(prompt_template_choices)}. Default to None.",
)

# Quantization settings.
parser.add_argument(
"--quantization",
"-q",
type=str,
choices=["awq", "gptq", "squeezellm", None],
default=None,
help="Method used to quantize the weights. If "
"None, we first check the `quantization_config` "
"attribute in the model config file. If that is "
"None, we assume the model weights are not "
"quantized and use `dtype` to determine the data "
"type of the weights.",
)
parser.add_argument(
"--enforce-eager",
action="store_true",
help="Always use eager-mode PyTorch. If False, "
"will use eager mode and CUDA graph in hybrid "
"for maximal performance and flexibility.",
)
return parser


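For reference, a minimal client call against the renamed `/completion` route, assuming the server above is already running locally on port 8000 (the payload mirrors the curl example in the module docstring):

```python
import requests

resp = requests.post(
    "http://127.0.0.1:8000/completion",
    json={"prompt": "hello, who are you? ", "stream": "False"},
    timeout=60,
)
print(resp.status_code)  # 200 if the server accepted the request
print(resp.json())       # the non-streaming completion result
```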
2 changes: 1 addition & 1 deletion colossalai/shardformer/layer/embedding.py
@@ -175,7 +175,7 @@ class VocabParallelEmbedding1D(ParallelModule):
he initializer of weight, defaults to normal initializer.

The ``args`` and ``kwargs`` used in :class:``torch.nn.functional.embedding`` should contain:

::
max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is
renormalized to have norm max_norm. Note: this will modify weight in-place.
norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2.
10 changes: 5 additions & 5 deletions examples/inference/client/locustfile.py
@@ -7,18 +7,18 @@ class QuickstartUser(HttpUser):
@tag("online-generation")
@task(5)
def completion(self):
self.client.post("/v1/completion", json={"prompt": "hello, who are you? ", "stream": "False"})
self.client.post("/completion", json={"prompt": "hello, who are you? ", "stream": "False"})

@tag("online-generation")
@task(5)
def completion_streaming(self):
self.client.post("/v1/completion", json={"prompt": "hello, who are you? ", "stream": "True"})
self.client.post("/completion", json={"prompt": "hello, who are you? ", "stream": "True"})

@tag("online-chat")
@task(5)
def chat(self):
self.client.post(
"v1/chat",
"/chat",
json={
"converation": [
{"role": "system", "content": "you are a helpful assistant"},
@@ -32,7 +32,7 @@ def chat_streaming(self):
@task(5)
def chat_streaming(self):
self.client.post(
"v1/chat",
"/chat",
json={
"converation": [
{"role": "system", "content": "you are a helpful assistant"},
@@ -55,4 +55,4 @@ def generate(self):
@tag("online-generation", "offline-generation")
@task
def get_models(self):
self.client.get("/v0/models")
self.client.get("/models")
Empty file added tests/test_infer/__init__.py
Empty file.