feat(batching): break batch into smaller partitions if it's larger than max batch size #4752

Merged · 2 commits · May 27, 2024
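What the change does, in brief: when a queued request's own batch size exceeds `max_batch_size`, the dispatcher now cuts it into partitions of at most `max_batch_size` along the configured `batch_dim`, runs each partition as a separate outbound call, and merges the partial outputs back into the original request's future. A minimal, hypothetical service to illustrate the setting — the endpoint, names, and sizes below are illustrative only and assume the 1.2-style `@bentoml.service` / `@bentoml.api(batchable=True, ...)` interface; they are not code from this PR:

import numpy as np

import bentoml


@bentoml.service
class Doubler:
    # Batchable endpoint: the dispatcher groups concurrent calls into batches
    # of at most max_batch_size along batch_dim.
    @bentoml.api(batchable=True, batch_dim=0, max_batch_size=8)
    def double(self, inputs: np.ndarray) -> np.ndarray:
        return inputs * 2

With this PR, a single call to such an endpoint that carries 20 rows is served as partitions of at most 8 rows (e.g. 8 + 8 + 4), and the caller still receives one combined 20-row response.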
1 change: 1 addition & 0 deletions src/_bentoml_impl/server/app.py
@@ -119,6 +119,7 @@ def fallback() -> t.NoReturn:
get_batch_size=functools.partial(
AutoContainer.get_batch_size, batch_dim=method.batch_dim[0]
),
batch_dim=method.batch_dim,
)

@functools.cached_property
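The one-line change above forwards the endpoint's `batch_dim` into `CorkDispatcher`; the split/merge path added in dispatcher.py (next file) uses it to cut an oversized input along `batch_dim[0]` via `AutoContainer.batch_to_batches` and to re-join the partial outputs along `batch_dim[1]` via `AutoContainer.batches_to_batch`. A rough NumPy stand-in for that split/merge, purely illustrative (plain `numpy` slicing here, not the actual `AutoContainer` implementation):

import numpy as np

batch = np.arange(20).reshape(10, 2)   # one oversized input: 10 rows, batch dim 0
indices = [0, 4, 10]                   # split indexes in the [0, chunk_size, next_batch_size] form used below
parts = [batch[i:j] for i, j in zip(indices, indices[1:])]   # partitions of 4 and 6 rows
merged = np.concatenate(parts, axis=0)                       # outputs re-joined along the batch dim
assert (merged == batch).all()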
175 changes: 101 additions & 74 deletions src/bentoml/_internal/marshal/dispatcher.py
@@ -5,7 +5,6 @@
import functools
import logging
import time
import traceback
import typing as t
from functools import cached_property

@@ -117,6 +116,10 @@ class CorkDispatcher(t.Generic[T_IN, T_OUT]):
The wrapped function should be an async function.
"""

callback: t.Callable[[t.Sequence[T_IN]], t.Coroutine[None, None, t.Sequence[T_OUT]]]
# the interval between each poll
TICKET_INTERVAL = 0.001

def __init__(
self,
max_latency_in_ms: int,
@@ -125,13 +128,16 @@ def __init__(
shared_sema: t.Optional[NonBlockSema] = None,
fallback: t.Callable[[], T_OUT],
get_batch_size: t.Callable[[T_IN], int] = lambda x: x.sample.batch_size,
batch_dim: tuple[int, int] = (0, 0),
) -> None:
...
"""
params:
* max_latency_in_ms: max_latency_in_ms for inbound tasks in milliseconds
* max_batch_size: max batch size of inbound tasks
* shared_sema: semaphore to limit concurrent outbound tasks
* get_batch_size: callable to get batch size from inputs
* batch_dim: tuple of (input_dim, output_dim) of batch
* fallback: callable to return fallback result
raises:
* all possible exceptions the decorated function has
@@ -140,13 +146,14 @@
self.fallback = fallback
self.optimizer = Optimizer(self.max_latency)
self.max_batch_size = int(max_batch_size)
self.tick_interval = 0.001

self._controller = None
self._queue: collections.deque[Job[T_IN, T_OUT]] = (
collections.deque()
) # TODO(bojiang): maxlen
self.get_batch_size = get_batch_size
self.batch_dim = batch_dim
# at most 1 batch can be processed at the same time
self._sema = shared_sema if shared_sema else NonBlockSema(1)

def shutdown(self) -> None:
@@ -201,8 +208,8 @@ async def train_optimizer(
if training_batch_size > 1:
wait = min(
self.max_latency * 0.95,
(training_batch_size * 2 + 1)
* (self.optimizer.o_a + self.optimizer.o_b),
# wait for at most the approximate time needed to process 2n + 1 requests
(training_batch_size * 2 + 1) * self.optimizer.o_a + self.optimizer.o_b,
)

req_count = 0
@@ -213,6 +220,7 @@

n = len(self._queue)
now = time.time()
# the wait time of the first request
w0 = now - self._queue[0].enqueue_time

# only cancel requests if there are more than enough for training
@@ -228,45 +236,19 @@
n < training_batch_size
and (training_batch_size * a + b) + w0 <= wait
):
await asyncio.sleep(self.tick_interval)
await asyncio.sleep(self.TICKET_INTERVAL)
continue
if self._sema.is_locked():
await asyncio.sleep(self.tick_interval)
await asyncio.sleep(self.TICKET_INTERVAL)
continue

if self.max_batch_size == -1: # batching is disabled
n_call_out = 1
batch_size = self.get_batch_size(self._queue[0].data)
else:
n_call_out = 0
batch_size = 0
try:
for input_info in self._queue:
if (
batch_size + self.get_batch_size(input_info.data)
< self.max_batch_size
):
n_call_out += 1
batch_size += self.get_batch_size(input_info.data)
else:
break
except Exception as e:
n_call_out = min(n, self.max_batch_size)
logger.error(
"error in batch-size aware batching, falling back to regular batching method",
exc_info=e,
)

req_count += 1
# call
self._sema.acquire()
inputs_info = tuple(self._queue.popleft() for _ in range(n_call_out))
for info in inputs_info:
# fake wait as 0 for training requests
info.enqueue_time = now
self._loop.create_task(self.outbound_call(inputs_info))
except Exception as e: # pylint: disable=broad-except
logger.error(traceback.format_exc(), exc_info=e)
inputs_info = tuple(self._get_inputs())
self._loop.create_task(self.outbound_call(inputs_info, training=True))
except Exception:
logger.exception("Error in training optimizer")

async def controller(self) -> None:
"""
@@ -298,19 +280,21 @@ async def controller(self) -> None:
"BentoML has detected that a service has a max latency that is likely too low for serving. If many 503 errors are encountered, try raising the 'runner.max_latency' in your BentoML configuration YAML file."
)
logger.debug("Dispatcher optimizer training complete.")
except Exception as e: # pylint: disable=broad-except
logger.error(traceback.format_exc(), exc_info=e)
except Exception: # pylint: disable=broad-except
logger.exception("Error training optimizer")

while True:
try:
async with self._wake_event: # block until there's any request in queue
await self._wake_event.wait_for(self._queue.__len__)

n = len(self._queue)
dt = self.tick_interval
dt = self.TICKET_INTERVAL
decay = 0.95 # the decay rate of wait time
now = time.time()
# the wait time of the first request
w0 = now - self._queue[0].enqueue_time
# the wait time of the last request
wn = now - self._queue[-1].enqueue_time
a = self.optimizer.o_a
b = self.optimizer.o_b
@@ -325,50 +309,22 @@
if n == 1 and w0 >= self.max_latency:
self._queue.popleft().future.cancel()
continue
await asyncio.sleep(self.tick_interval)
await asyncio.sleep(self.TICKET_INTERVAL)
continue
if (
n < self.max_batch_size
and n * (wn + dt + (a or 0)) <= self.optimizer.wait * decay
):
await asyncio.sleep(self.tick_interval)
await asyncio.sleep(self.TICKET_INTERVAL)
continue

if self.max_batch_size == -1: # batching is disabled
n_call_out = 1
batch_size = self.get_batch_size(self._queue[0].data)
else:
n_call_out = 0
batch_size = 0
try:
for input_info in self._queue:
if (
batch_size + self.get_batch_size(input_info.data)
< self.max_batch_size
):
n_call_out += 1
batch_size += self.get_batch_size(input_info.data)
else:
break
except Exception as e:
n_call_out = min(n, self.max_batch_size)
logger.error(
"error in batch-size aware batching, falling back to regular batching method",
exc_info=e,
)
# call
self._sema.acquire()
inputs_info = tuple(self._queue.popleft() for _ in range(n_call_out))
inputs_info = tuple(self._get_inputs())
self._loop.create_task(self.outbound_call(inputs_info))
except Exception as e: # pylint: disable=broad-except
logger.error(traceback.format_exc(), exc_info=e)
except Exception:
logger.exception("Error processing batch requests")

async def inbound_call(self, data: T_IN) -> T_OUT | Exception:
if self.max_batch_size > 0 and self.get_batch_size(data) > self.max_batch_size:
raise RuntimeError(
f"batch of size {self.get_batch_size(data)} exceeds configured max batch size of {self.max_batch_size}."
)

now = time.time()
future: asyncio.Future[T_OUT | Exception] = self._loop.create_future()
input_info = Job(now, data, future)
@@ -377,7 +333,9 @@ async def inbound_call(self, data: T_IN) -> T_OUT | Exception:
self._wake_event.notify_all()
return await future

async def outbound_call(self, inputs_info: tuple[Job[T_IN, T_OUT], ...]):
async def outbound_call(
self, inputs_info: tuple[Job[T_IN, T_OUT], ...], training: bool = False
):
_time_start = time.time()
_done = False
batch_size = len(inputs_info)
@@ -394,7 +352,8 @@ async def outbound_call(self, inputs_info: tuple[Job[T_IN, T_OUT], ...]):
_done = True
self.optimizer.log_outbound(
n=len(inputs_info),
wait=_time_start - inputs_info[-1].enqueue_time,
# fake wait as 0 for training requests
wait=_time_start - inputs_info[0].enqueue_time if not training else 0,
duration=time.time() - _time_start,
)
except Exception as e: # pylint: disable=broad-except
@@ -410,3 +369,71 @@ async def outbound_call(self, inputs_info: tuple[Job[T_IN, T_OUT], ...]):
if not fut.done():
fut.cancel()
self._sema.release()

def _get_inputs(
self, num_batches: int | None = None
) -> t.Iterable[Job[T_IN, T_OUT]]:
if num_batches is None:
num_batches = self.max_batch_size
batch_size = 0
while len(self._queue) > 0 and batch_size < num_batches:
try:
next_batch_size = self.get_batch_size(self._queue[0].data)
except Exception:
logger.exception(
"error in batch-size aware batching, falling back to regular batching method",
)
next_batch_size = 1
if batch_size + next_batch_size <= num_batches:
batch_size += next_batch_size
yield self._queue.popleft()
else:
job = self._queue.popleft()
chunk_size = num_batches - batch_size # chunk_size < next_batch_size
first_job, second_job = self.split_job(
job, [0, chunk_size, next_batch_size]
)
batch_size += chunk_size
self._queue.appendleft(second_job)
yield first_job

def split_job(
self, job: Job[T_IN, T_OUT], split_indexes: list[int]
) -> tuple[Job[T_IN, T_OUT], Job[T_IN, T_OUT]]:
"""Split the job into two child jobs and process them separately."""
from ..runner.container import AutoContainer

logger.debug(
"Splitting batch into two child batches with indexes: %s", split_indexes
)
split_batches = AutoContainer.batch_to_batches(
job.data, split_indexes, self.batch_dim[0]
)
assert len(split_batches) == 2
first_child = attr.evolve(
job, data=split_batches[0], future=self._loop.create_future()
)
second_child = attr.evolve(
job, data=split_batches[1], future=self._loop.create_future()
)

def child_done_callback(fut: asyncio.Future[T_OUT | Exception]):
if fut.cancelled():
job.future.cancel()
return
if job.future.done():
return
result = fut.result()
if isinstance(result, Exception):
job.future.set_result(result)
return
if first_child.future.done() and second_child.future.done():
result, _ = AutoContainer.batches_to_batch(
[first_child.future.result(), second_child.future.result()],
self.batch_dim[1],
)
job.future.set_result(result)

first_child.future.add_done_callback(child_done_callback)
second_child.future.add_done_callback(child_done_callback)
return first_child, second_child
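To make the new `_get_inputs` / `split_job` behavior concrete, here is a small, self-contained simulation of the greedy partitioning it performs (a hypothetical helper, not part of the PR, and ignoring the `get_batch_size` error fallback): each queue entry stands for a job's batch size, jobs are drained until the next one would overflow the size budget, and the overflowing job is split at `max_size - accumulated`, with the remainder pushed back to the head of the queue — mirroring the `[0, chunk_size, next_batch_size]` split indexes above.

from __future__ import annotations

from collections import deque


def partition(queue: deque[int], max_size: int) -> list[list[int]]:
    """Return the per-outbound-call partitions, given queued job batch sizes."""
    calls: list[list[int]] = []
    while queue:
        current: list[int] = []
        size = 0
        while queue and size < max_size:
            next_size = queue[0]
            if size + next_size <= max_size:
                current.append(queue.popleft())   # whole job fits in this call
                size += next_size
            else:
                chunk = max_size - size           # chunk_size < next_batch_size
                queue[0] = next_size - chunk      # remainder stays at the queue head
                current.append(chunk)
                size += chunk
        calls.append(current)
    return calls


print(partition(deque([10]), 4))    # [[4], [4], [2]]  -> one job served by three calls
print(partition(deque([3, 5]), 4))  # [[3, 1], [4]]    -> second job split across two calls

In the real dispatcher, `split_job` additionally wires the two child futures back into the original job's future via `child_done_callback`, so the caller still observes a single response even when its batch was processed in pieces.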