Revert "[inference]Re push async dynamic batching" #4905

Merged 1 commit on Oct 13, 2023
15 changes: 5 additions & 10 deletions colossalai/inference/dynamic_batching/io_struct.py
@@ -4,7 +4,7 @@


class Req:
def __init__(self, request_id, prompt_ids, sample_params: SamplingParams, prompts: str = ""):
def __init__(self, request_id, prompt_ids, sample_params: SamplingParams):
self.request_id = request_id
self.prompt_ids = prompt_ids
self.input_len = len(prompt_ids)
@@ -14,7 +14,6 @@ def __init__(self, request_id, prompt_ids, sample_params: SamplingParams, prompt
self.output_metadata_list = []
self.has_generate_finished = False
self.aborted = False
self.prompts = prompts

def to_rpc_obj(self):
return {
@@ -37,11 +36,7 @@ def stop_sequences_matched(self):
if self.sample_params.stop_sequences is not None:
for stop_token_ids in self.sample_params.stop_sequences:
stop_len = len(stop_token_ids)
if (
stop_len > 0
and len(self.output_ids) >= stop_len
and all(self.output_ids[-(stop_len - i)] == stop_token_ids[i] for i in range(stop_len))
):
if stop_len > 0 and len(self.output_ids) >= stop_len and all(self.output_ids[-(stop_len - i)] == stop_token_ids[i] for i in range(stop_len)):
return True
return False

@@ -107,7 +102,7 @@ def mark_finished_req(self, eos_id):
has_new_finish = True
return has_new_finish

def filter_finished(self) -> List[Req]:
def filter_finished(self)->List[Req]:
"""
Filter finished requests from the batch; the finished ones are removed from 'reqs'.
"""
@@ -116,9 +111,9 @@ def filter_finished(self) -> List[Req]:
finished_req = []
for req in self.reqs:
if not req.has_generate_finished:
unfinished_req.append(req)
unfinished_req.append(req)
else:
finished_req.append(req)
finished_req.append(req)
self.reqs = unfinished_req
self.id_to_reqs = {req.request_id: req for req in self.reqs}
return finished_req
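
As a side note, the condition preserved in the stop_sequences_matched hunk above checks whether the tail of output_ids equals one of the configured stop sequences. A minimal standalone sketch of that suffix comparison (the helper name suffix_matches and the sample token ids are illustrative, not part of the diff):

from typing import List

def suffix_matches(output_ids: List[int], stop_token_ids: List[int]) -> bool:
    # True when output_ids ends with stop_token_ids, mirroring Req.stop_sequences_matched
    stop_len = len(stop_token_ids)
    return (
        stop_len > 0
        and len(output_ids) >= stop_len
        and all(output_ids[-(stop_len - i)] == stop_token_ids[i] for i in range(stop_len))
    )

assert suffix_matches([5, 6, 7, 2], [7, 2])      # generated tail matches the stop sequence
assert not suffix_matches([5, 6, 7, 2], [6, 7])  # stop sequence present but not at the tail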
139 changes: 63 additions & 76 deletions colossalai/inference/manager.py
@@ -1,7 +1,6 @@
import asyncio
import time
from typing import List

from transformers import AutoTokenizer
import asyncio

from .dynamic_batching.infer_batch import InferBatch
from .dynamic_batching.io_struct import Batch, Req
@@ -10,17 +9,16 @@
from .dynamic_batching.stats import Stats
from .tensor_parallel import TPInferEngine

from transformers import AutoTokenizer
_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer"


class DynamicBatchManager:
def __init__(
self,
tp_engine: TPInferEngine,
max_total_token_num,
batch_max_tokens,
eos_id,
model,
log_stats=True,
log_stats_interval=10,
running_batch: Batch = None,
@@ -32,7 +30,6 @@ def __init__(
batch_max_tokens : max tokens of one batch, defaults to (max input + output len) * num_requests
running_max_req_size : max request size of the running batch, equal to MAX_BATCH_SIZE of the tp engine
eos_id : The end token of a seq
model: the model weight dir path, the app will load config, weights and tokenizer from this dir
log_stats : whether to log stats
log_stats_interval : log stats interval
running_batch : running batch
@@ -48,32 +45,32 @@ def __init__(
self.eos_id = eos_id
self.has_wait_tokens = 0
self.max_wait_tokens = 10
self.model = model


self.stats_tool = Stats(log_stats, log_stats_interval)
self.mem_usage_interval = log_stats_interval * 2
self._set_tokenizer(tokenizer_name=self.model)

async def add_req(self, request_id, prompt_ids: List[int], sampling_params: SamplingParams, prompts: str = ""):
def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str):
"""
Add a new request to the request queue; during initialization all requests are held in the waiting list.
"""
req = Req(request_id, prompt_ids, sampling_params, prompts)
req = Req(request_id, prompt_ids, sampling_params)
self.req_queue.append(req)
return

async def add_input(self, request_id, sampling_params, prompts):
def add_input(self, request_id, sampling_params, input_ids):
"""
Encode and add new input to the request queue. Only single-sequence input is supported for now.
"""
prompt_ids = self.tokenizer.encode(prompts)
prompt_ids = self.tokenizer.encode(input_ids)
prompt_len = len(prompt_ids)
if prompt_len > self.engine.max_input_len:
raise ValueError(f"the input prompt token len {prompt_len} is too long > {self.engine.max_input_len}")
raise ValueError(
f"the input prompt token len {prompt_len} is too long > {self.engine.max_input_len}"
)
sampling_params.stop_sentences_to_token_ids(self.tokenizer)
self.add_req(request_id, prompt_ids, sampling_params, prompts)
self.add_req(prompt_ids, sampling_params, request_id)
return

def abort(self, request_id):
if self.running_batch is not None:
for req in self.running_batch.reqs:
@@ -91,15 +88,10 @@ async def loop_for_fwd(self):
The main loop for a dynamic batching process.
"""
counter_count = 0
# self.running_batch is not None or self.req_queue.waiting_req_list
#self.running_batch is not None or self.req_queue.waiting_req_list
while True:
if self.running_batch is not None or self.req_queue.waiting_req_list:
async for result in self._step():
yield result
else:
# need to wait for new requests
await asyncio.sleep(0.1)
continue
async for item in self._step():
yield item
counter_count += 1
if self.running_batch is not None:
if counter_count % self.mem_usage_interval == 0:
@@ -111,33 +103,30 @@ async def loop_for_fwd(self):
)
self.stats_tool.print_stats()

def _set_tokenizer(
self, tokenizer=None, tokenizer_name: str = "", trust_remote_code: bool = False, use_fast: bool = True
):
if self.running_batch is None:
time.sleep(0.1) # 10ms

def _set_tokenizer(self, tokenizer=None, tokenizer_name: str = "", trust_remote_code: bool = False, use_fast:bool = True,):
if tokenizer is not None:
self.tokenizer = tokenizer
self.tokenizer = tokenizer
else:
if "llama" in tokenizer_name.lower() and use_fast == True:
print(
"For some LLaMA-based models, initializing the fast tokenizer may "
"take a long time. To eliminate the initialization time, consider "
f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original "
"tokenizer. This is done automatically in Colossalai."
)

tokenizer_name = _FAST_LLAMA_TOKENIZER

try:
self.tokenizer = AutoTokenizer.from_pretrained(
tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code
)
except TypeError:
"For some LLaMA-based models, initializing the fast tokenizer may "
"take a long time. To eliminate the initialization time, consider "
f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original "
"tokenizer. This is done automatically in Colossalai.")

tokenizer_name = _FAST_LLAMA_TOKENIZER

try:
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=use_fast,trust_remote_code=trust_remote_code)
except TypeError as e:
use_fast = False
self.tokenizer = AutoTokenizer.from_pretrained(
tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code
)
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=use_fast,trust_remote_code=trust_remote_code)


async def _step(self):
def _step(self):
"""
Logic for handling requests
"""
@@ -147,36 +136,33 @@ async def _step(self):
if new_batch is not None:
self.stats_tool.count_prompt_tokens(new_batch)
self.running_batch = new_batch
async for item in self._prefill_batch(self.running_batch):
yield item
yield from self._prefill_batch(self.running_batch)
self._filter_runing_batch()
self.has_wait_tokens = 0
return

if self.has_wait_tokens < self.max_wait_tokens:
self.stats_tool.count_output_tokens(self.running_batch)
self._decode_batch(self.running_batch)
yield from self._decode_batch(self.running_batch)
self._filter_runing_batch()
self.has_wait_tokens += 1
return
else:
new_mini_batch = self.req_queue.generate_new_batch(self.running_batch)
if new_mini_batch is not None:
self.stats_tool.count_prompt_tokens(new_mini_batch)
async for item in self._prefill_batch(new_mini_batch):
yield item
yield from self._prefill_batch(new_mini_batch)
if not new_mini_batch.is_clear():
self._merge_batch(self.running_batch, new_mini_batch)
self.running_batch.merge(new_mini_batch)
self.has_wait_tokens = 0

else:
self.stats_tool.count_output_tokens(self.running_batch)
async for item in self._decode_batch(self.running_batch):
yield item
yield from self._decode_batch(self.running_batch)
self._filter_runing_batch()
self.has_wait_tokens += 1

return

def _init_batch(self, batch: Batch, dtype="fp16"):
@@ -201,7 +187,7 @@ def _init_batch(self, batch: Batch, dtype="fp16"):
)
self.engine.cache[batch_id] = batch_data

async def _prefill_batch(self, batch):
def _prefill_batch(self, batch):
"""
For every batch, whether it is a new batch or a mini batch, prefill must run first.
"""
@@ -212,20 +198,19 @@ async def _prefill_batch(self, batch):
req_to_out_token_id = ans
self._add_token_id_to_req(batch, req_to_out_token_id)
has_new_finished_req = batch.mark_finished_req(self.eos_id)
async for item in self._handle_finish_req(batch, has_new_finished_req):
yield item
yield from self._handle_finish_req(batch, has_new_finished_req)

# delete finished reqs

async def _decode_batch(self, batch: Batch):
def _decode_batch(self, batch: Batch):
"""
Decoding process
"""
ans = self.engine._decode_batch(batch.batch_id)
req_to_out_token_id = ans
self._add_token_id_to_req(batch, req_to_out_token_id)
has_new_finished_req = batch.mark_finished_req(self.eos_id)
async for item in self._handle_finish_req(batch, has_new_finished_req):
yield item
yield from self._handle_finish_req(batch, has_new_finished_req)

def _filter_batch(self, batch: Batch):
batch_id = batch.batch_id
@@ -255,15 +240,15 @@ def _remove_batch(self, batch):
batch.free_self()
del batch

async def _handle_finish_req(self, batch: Batch, has_new_finished_req):
def _handle_finish_req(self, batch: Batch, has_new_finished_req):
if has_new_finished_req:
finished_reqs = batch.filter_finished()
finished_reqs=batch.filter_finished()
if batch.is_clear():
self._remove_batch(batch)
else:
self._filter_batch(batch)
async for item in self._output_process(finished_reqs):
yield item
yield from self._output_process(finished_reqs)


def _filter_runing_batch(self):
if self.running_batch is not None and self.running_batch.is_clear():
@@ -282,24 +267,18 @@ async def _output_process(self, finished_reqs: List[Req]):
"""
for req in finished_reqs:
output = self.tokenizer.decode(req.output_ids)
yield req.prompts + output
yield output, req.request_id, req.output_metadata_list

def clean_up(self):
# this logic should be implemented in the future.
pass

async def generate(self, request_id, prompt_id, sampling_params):
async def generate(self,request_id,prompt_id,sampling_params):
"""
Generate the output of a request.
"""

await self.add_input(request_id, prompt_id, sampling_params)


async def process_data(dbm):
async for data in dbm.loop_for_fwd():
print(data)

self.add_input(request_id,prompt_id,sampling_params)


def start_dynamic_batching(args, tp_engine, waiting_req_list):
try:
@@ -308,13 +287,21 @@ def start_dynamic_batching(args, tp_engine, waiting_req_list):
max_total_token_num=args.max_total_token_num,
batch_max_tokens=args.batch_max_tokens,
eos_id=args.eos_id,
model=args.model,
log_stats=not args.disable_log_stats,
log_stats_interval=args.log_stats_interval,
waiting_req_list=waiting_req_list,
)

except Exception:
raise RuntimeError("Failed to start dynamic batching")
batch_manager.clean_up()
raise

batch_manager._set_tokenizer(tokenizer_name = tp_engine.model.__class__.__name__)
prod_task = asyncio.create_task(batch_manager.add_input(4,sampling_params=SamplingParams(),input_ids="hello world"))

asyncio.run(prod_task)

for item in batch_manager.loop_for_fwd():
print(item)

return batch_manager
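
Most of the diff in manager.py replaces async generator delegation (async for item in self._step(): yield item) with plain generators delegated via yield from. A minimal sketch of the two styles, using illustrative names that are not taken from the engine code:

import asyncio

def sync_step():
    # plain generator: results are yielded directly and can be delegated with `yield from`
    yield from (f"token-{i}" for i in range(3))

def sync_loop():
    yield from sync_step()  # delegation style used after the revert

async def async_step():
    for i in range(3):
        await asyncio.sleep(0)  # stand-in for an awaited engine call
        yield f"token-{i}"

async def async_loop():
    # async generators cannot use `yield from`, so each item is re-yielded explicitly
    async for item in async_step():
        yield item

print(list(sync_loop()))

async def main():
    print([item async for item in async_loop()])

asyncio.run(main())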
33 changes: 33 additions & 0 deletions colossalai/inference/test_async.py
@@ -0,0 +1,33 @@
import asyncio

shared_list = []

async def producer():
for i in range(5):
await asyncio.sleep(1) # simulate fetching data asynchronously
shared_list.append(i)
print(f"Produced {i}")

async def consumer():
last_index = 0
while True:
await asyncio.sleep(0.5) # small delay so the polling loop does not spin too tightly
if last_index < len(shared_list):
item = shared_list[last_index]
print(f"Consumed {item}")
yield item
last_index += 1

async def main():
# create the producer task (the consumer is driven by the async for below)
prod_task = asyncio.create_task(producer())

# wait for the producer task to complete
await prod_task

async for data in consumer():
print(data)
# for the purposes of this example, just wait a while and then stop the consumer
await asyncio.sleep(5)

asyncio.run(main())
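
The test above polls shared_list with short sleeps. For comparison only (not part of this PR), a queue-based sketch of the same producer/consumer flow avoids the polling by blocking on asyncio.Queue.get:

import asyncio

async def producer(queue: asyncio.Queue) -> None:
    for i in range(5):
        await asyncio.sleep(0.1)  # stand-in for asynchronously produced data
        await queue.put(i)
    await queue.put(None)  # sentinel: signal that production is finished

async def consumer(queue: asyncio.Queue) -> None:
    while True:
        item = await queue.get()  # wakes only when an item is available
        if item is None:
            break
        print(f"Consumed {item}")

async def main() -> None:
    queue: asyncio.Queue = asyncio.Queue()
    await asyncio.gather(producer(queue), consumer(queue))

asyncio.run(main())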