
Commit

feat(batching): break batch into smaller partitions if it's larger than max batch size

Signed-off-by: Frost Ming <me@frostming.com>
frostming committed May 23, 2024
1 parent 1977c42 commit 0c3fe29
Showing 5 changed files with 239 additions and 221 deletions.
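In short: when the job at the head of the queue would push the current batch past max_batch_size, the dispatcher now splits that job into two partitions instead of dispatching an oversized batch or rejecting the request. A minimal conceptual sketch of that greedy packing with a split at the boundary, using plain Python lists in place of the real Job objects and AutoContainer helpers (pack_jobs and its signature are hypothetical, not part of this commit):

from typing import List, Tuple

# Hypothetical illustration only: greedily pack queued jobs until the batch
# budget is reached, splitting the job that would overflow into two chunks.
def pack_jobs(queue: List[List[int]], max_batch_size: int) -> Tuple[List[List[int]], List[List[int]]]:
    taken: List[List[int]] = []   # jobs (or job chunks) for this outbound call
    size = 0
    while queue and size < max_batch_size:
        job = queue[0]
        if size + len(job) <= max_batch_size:
            taken.append(queue.pop(0))
            size += len(job)
        else:
            chunk = max_batch_size - size       # room left in this batch
            taken.append(job[:chunk])
            queue[0] = job[chunk:]              # leftover stays at the head
            size = max_batch_size
    return taken, queue

# A queue with batch sizes 3, 4 and 6 against max_batch_size=8: the third job
# is split into a chunk of 1 now and a leftover of 5 for the next call.
batch, rest = pack_jobs([[1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11, 12, 13]], 8)
print(batch)  # [[1, 2, 3], [4, 5, 6, 7], [8]]
print(rest)   # [[9, 10, 11, 12, 13]]

The real logic lives in _get_inputs and split_job in dispatcher.py below: the leftover chunk is pushed back to the front of the queue as a new Job, and the original request's result is reassembled from the two child futures.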
1 change: 1 addition & 0 deletions src/_bentoml_impl/server/app.py
@@ -119,6 +119,7 @@ def fallback() -> t.NoReturn:
get_batch_size=functools.partial(
AutoContainer.get_batch_size, batch_dim=method.batch_dim[0]
),
batch_dim=method.batch_dim,
)

@functools.cached_property
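The only change here is to forward the method's batch_dim into the dispatcher, so oversized batches can later be sliced along the correct axis. As a rough illustration of what batch_dim means, assuming array-like inputs batched along the first axis (NumPy is used purely for demonstration; this is not BentoML code):

import numpy as np

batch = np.zeros((10, 3))                     # batch size 10 along dim 0
max_batch_size = 8
head, tail = batch[:max_batch_size], batch[max_batch_size:]
print(head.shape, tail.shape)                 # (8, 3) (2, 3)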
175 changes: 101 additions & 74 deletions src/bentoml/_internal/marshal/dispatcher.py
@@ -5,7 +5,6 @@
import functools
import logging
import time
import traceback
import typing as t
from functools import cached_property

@@ -117,6 +116,10 @@ class CorkDispatcher(t.Generic[T_IN, T_OUT]):
The wrapped function should be an async function.
"""

callback: t.Callable[[t.Sequence[T_IN]], t.Coroutine[None, None, t.Sequence[T_OUT]]]
# the interval between each poll
TICKET_INTERVAL = 0.001

def __init__(
self,
max_latency_in_ms: int,
@@ -125,13 +128,16 @@ def __init__(
shared_sema: t.Optional[NonBlockSema] = None,
fallback: t.Callable[[], T_OUT],
get_batch_size: t.Callable[[T_IN], int] = lambda x: x.sample.batch_size,
batch_dim: tuple[int, int] = (0, 0),
) -> None:
...
"""
params:
* max_latency_in_ms: maximum latency for inbound tasks, in milliseconds
* max_batch_size: max batch size of inbound tasks
* shared_sema: semaphore to limit concurrent outbound tasks
* get_batch_size: callable to get batch size from inputs
* batch_dim: tuple of (input_dim, output_dim) of batch
* fallback: callable to return fallback result
raises:
* all possible exceptions the decorated function has
@@ -140,13 +146,14 @@
self.fallback = fallback
self.optimizer = Optimizer(self.max_latency)
self.max_batch_size = int(max_batch_size)
self.tick_interval = 0.001

self._controller = None
self._queue: collections.deque[Job[T_IN, T_OUT]] = (
collections.deque()
) # TODO(bojiang): maxlen
self.get_batch_size = get_batch_size
self.batch_dim = batch_dim
# at most 1 batch can be processed at the same time
self._sema = shared_sema if shared_sema else NonBlockSema(1)

def shutdown(self) -> None:
@@ -201,8 +208,8 @@ async def train_optimizer(
if training_batch_size > 1:
wait = min(
self.max_latency * 0.95,
(training_batch_size * 2 + 1)
* (self.optimizer.o_a + self.optimizer.o_b),
# wait for at most approximated time to process 2n + 1 requests
(training_batch_size * 2 + 1) * self.optimizer.o_a + self.optimizer.o_b,
)

req_count = 0
@@ -213,6 +220,7 @@

n = len(self._queue)
now = time.time()
# the wait time of the first request
w0 = now - self._queue[0].enqueue_time

# only cancel requests if there are more than enough for training
@@ -228,45 +236,19 @@
n < training_batch_size
and (training_batch_size * a + b) + w0 <= wait
):
await asyncio.sleep(self.tick_interval)
await asyncio.sleep(self.TICKET_INTERVAL)
continue
if self._sema.is_locked():
await asyncio.sleep(self.tick_interval)
await asyncio.sleep(self.TICKET_INTERVAL)
continue

if self.max_batch_size == -1: # batching is disabled
n_call_out = 1
batch_size = self.get_batch_size(self._queue[0].data)
else:
n_call_out = 0
batch_size = 0
try:
for input_info in self._queue:
if (
batch_size + self.get_batch_size(input_info.data)
< self.max_batch_size
):
n_call_out += 1
batch_size += self.get_batch_size(input_info.data)
else:
break
except Exception as e:
n_call_out = min(n, self.max_batch_size)
logger.error(
"error in batch-size aware batching, falling back to regular batching method",
exc_info=e,
)

req_count += 1
# call
self._sema.acquire()
inputs_info = tuple(self._queue.popleft() for _ in range(n_call_out))
for info in inputs_info:
# fake wait as 0 for training requests
info.enqueue_time = now
self._loop.create_task(self.outbound_call(inputs_info))
except Exception as e: # pylint: disable=broad-except
logger.error(traceback.format_exc(), exc_info=e)
inputs_info = tuple(self._get_inputs())
self._loop.create_task(self.outbound_call(inputs_info, training=True))
except Exception:
logger.exception("Error in training optimizer")

async def controller(self) -> None:
"""
@@ -298,19 +280,21 @@ async def controller(self) -> None:
"BentoML has detected that a service has a max latency that is likely too low for serving. If many 503 errors are encountered, try raising the 'runner.max_latency' in your BentoML configuration YAML file."
)
logger.debug("Dispatcher optimizer training complete.")
except Exception as e: # pylint: disable=broad-except
logger.error(traceback.format_exc(), exc_info=e)
except Exception: # pylint: disable=broad-except
logger.exception("Error training optimizer")

while True:
try:
async with self._wake_event: # block until there's any request in queue
await self._wake_event.wait_for(self._queue.__len__)

n = len(self._queue)
dt = self.tick_interval
dt = self.TICKET_INTERVAL
decay = 0.95 # the decay rate of wait time
now = time.time()
# the wait time of the first request
w0 = now - self._queue[0].enqueue_time
# the wait time of the last request
wn = now - self._queue[-1].enqueue_time
a = self.optimizer.o_a
b = self.optimizer.o_b
@@ -325,50 +309,22 @@
if n == 1 and w0 >= self.max_latency:
self._queue.popleft().future.cancel()
continue
await asyncio.sleep(self.tick_interval)
await asyncio.sleep(self.TICKET_INTERVAL)
continue
if (
n < self.max_batch_size
and n * (wn + dt + (a or 0)) <= self.optimizer.wait * decay
):
await asyncio.sleep(self.tick_interval)
await asyncio.sleep(self.TICKET_INTERVAL)
continue

if self.max_batch_size == -1: # batching is disabled
n_call_out = 1
batch_size = self.get_batch_size(self._queue[0].data)
else:
n_call_out = 0
batch_size = 0
try:
for input_info in self._queue:
if (
batch_size + self.get_batch_size(input_info.data)
< self.max_batch_size
):
n_call_out += 1
batch_size += self.get_batch_size(input_info.data)
else:
break
except Exception as e:
n_call_out = min(n, self.max_batch_size)
logger.error(
"error in batch-size aware batching, falling back to regular batching method",
exc_info=e,
)
# call
self._sema.acquire()
inputs_info = tuple(self._queue.popleft() for _ in range(n_call_out))
inputs_info = tuple(self._get_inputs())
self._loop.create_task(self.outbound_call(inputs_info))
except Exception as e: # pylint: disable=broad-except
logger.error(traceback.format_exc(), exc_info=e)
except Exception:
logger.exception("Error processing batch requests")

async def inbound_call(self, data: T_IN) -> T_OUT | Exception:
if self.max_batch_size > 0 and self.get_batch_size(data) > self.max_batch_size:
raise RuntimeError(
f"batch of size {self.get_batch_size(data)} exceeds configured max batch size of {self.max_batch_size}."
)

now = time.time()
future: asyncio.Future[T_OUT | Exception] = self._loop.create_future()
input_info = Job(now, data, future)
@@ -377,7 +333,9 @@ async def inbound_call(self, data: T_IN) -> T_OUT | Exception:
self._wake_event.notify_all()
return await future

async def outbound_call(self, inputs_info: tuple[Job[T_IN, T_OUT], ...]):
async def outbound_call(
self, inputs_info: tuple[Job[T_IN, T_OUT], ...], training: bool = False
):
_time_start = time.time()
_done = False
batch_size = len(inputs_info)
@@ -394,7 +352,8 @@ async def outbound_call(self, inputs_info: tuple[Job[T_IN, T_OUT], ...]):
_done = True
self.optimizer.log_outbound(
n=len(inputs_info),
wait=_time_start - inputs_info[-1].enqueue_time,
# fake wait as 0 for training requests
wait=_time_start - inputs_info[0].enqueue_time if not training else 0,
duration=time.time() - _time_start,
)
except Exception as e: # pylint: disable=broad-except
@@ -410,3 +369,71 @@ async def outbound_call(self, inputs_info: tuple[Job[T_IN, T_OUT], ...]):
if not fut.done():
fut.cancel()
self._sema.release()

def _get_inputs(
self, num_batches: int | None = None
) -> t.Iterable[Job[T_IN, T_OUT]]:
if num_batches is None:
num_batches = self.max_batch_size
batch_size = 0
while len(self._queue) > 0 and batch_size < num_batches:
try:
next_batch_size = self.get_batch_size(self._queue[0].data)
except Exception:
logger.exception(
"error in batch-size aware batching, falling back to regular batching method",
)
next_batch_size = 1
if batch_size + next_batch_size <= num_batches:
batch_size += next_batch_size
yield self._queue.popleft()
else:
job = self._queue.popleft()
chunk_size = num_batches - batch_size # chunk_size < next_batch_size
first_job, second_job = self.split_job(
job, [0, chunk_size, next_batch_size]
)
batch_size += chunk_size
self._queue.appendleft(second_job)
yield first_job

def split_job(
self, job: Job[T_IN, T_OUT], split_indexes: list[int]
) -> tuple[Job[T_IN, T_OUT], Job[T_IN, T_OUT]]:
"""Split the job into two child jobs and process them separately."""
from ..runner.container import AutoContainer

logger.debug(
"Splitting batch into two child batches with indexes: %s", split_indexes
)
split_batches = AutoContainer.batch_to_batches(
job.data, split_indexes, self.batch_dim[0]
)
assert len(split_batches) == 2
first_child = attr.evolve(
job, data=split_batches[0], future=self._loop.create_future()
)
second_child = attr.evolve(
job, data=split_batches[1], future=self._loop.create_future()
)

def child_done_callback(fut: asyncio.Future[T_OUT | Exception]):
if fut.cancelled():
job.future.cancel()
return
if job.future.done():
return
result = fut.result()
if isinstance(result, Exception):
job.future.set_result(result)
return
if first_child.future.done() and second_child.future.done():
result, _ = AutoContainer.batches_to_batch(
[first_child.future.result(), second_child.future.result()],
self.batch_dim[1],
)
job.future.set_result(result)

first_child.future.add_done_callback(child_done_callback)
second_child.future.add_done_callback(child_done_callback)
return first_child, second_child
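For reference, the recombination in split_job hangs off a shared done-callback attached to both child futures: the parent future resolves only after both children complete, and cancellation propagates upward early. A standalone sketch of that parent/child future pattern, with plain list concatenation standing in for AutoContainer.batches_to_batch (split_and_join is a hypothetical helper, not part of this commit):

import asyncio
from typing import List

async def split_and_join(data: List[int], chunk: int) -> List[int]:
    loop = asyncio.get_running_loop()
    parent: asyncio.Future[List[int]] = loop.create_future()
    first: asyncio.Future[List[int]] = loop.create_future()
    second: asyncio.Future[List[int]] = loop.create_future()

    def child_done(_: "asyncio.Future[List[int]]") -> None:
        # Mirrors child_done_callback above: propagate cancellation early and
        # combine the two partial results once both children have finished.
        if first.cancelled() or second.cancelled():
            if not parent.done():
                parent.cancel()
            return
        if first.done() and second.done() and not parent.done():
            parent.set_result(first.result() + second.result())

    first.add_done_callback(child_done)
    second.add_done_callback(child_done)

    # Stand-ins for the two outbound calls; here each simply echoes its slice.
    first.set_result(data[:chunk])
    second.set_result(data[chunk:])
    return await parent

print(asyncio.run(split_and_join(list(range(8)), chunk=3)))  # [0, 1, ..., 7]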