Skip to content

Commit

Permalink
feat(helper): add eta on progressbar for known-length inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
hanxiao committed Oct 15, 2021
1 parent 7d9e6fd commit c098951
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 12 deletions.
5 changes: 5 additions & 0 deletions jina/clients/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,11 @@ def _get_requests(
# override by the caller-specific kwargs
_kwargs.update(kwargs)

if hasattr(self._inputs, '__len__'):
self._inputs_length = max(1, len(self._inputs) / _kwargs['request_size'])
else:
self._inputs_length = None

if inspect.isasyncgen(self.inputs):
from ..request.asyncio import request_generator

Expand Down
6 changes: 5 additions & 1 deletion jina/clients/base/grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,11 @@ async def _get_results(
stub = jina_pb2_grpc.JinaRPCStub(channel)
self.logger.debug(f'connected to {self.args.host}:{self.args.port}')

cm1 = ProgressBar() if self.show_progress else nullcontext()
cm1 = (
ProgressBar(total_length=self._inputs_length)
if self.show_progress
else nullcontext()
)

with cm1 as p_bar:
async for resp in stub.Call(req_iter):
Expand Down
6 changes: 5 additions & 1 deletion jina/clients/base/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,11 @@ async def _get_results(

async with AsyncExitStack() as stack:
try:
cm1 = ProgressBar() if self.show_progress else nullcontext()
cm1 = (
ProgressBar(total_length=self._inputs_length)
if self.show_progress
else nullcontext()
)
p_bar = stack.enter_context(cm1)

proto = 'https' if self.args.https else 'http'
Expand Down
6 changes: 5 additions & 1 deletion jina/clients/base/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,11 @@ async def _get_results(

async with AsyncExitStack() as stack:
try:
cm1 = ProgressBar() if self.show_progress else nullcontext()
cm1 = (
ProgressBar(total_length=self._inputs_length)
if self.show_progress
else nullcontext()
)
p_bar = stack.enter_context(cm1)

proto = 'wss' if self.args.https else 'ws'
Expand Down
23 changes: 15 additions & 8 deletions jina/logging/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@
from functools import wraps
from typing import Optional, Union, Callable

from .. import __windows__
from jina.enums import ProgressBarStatus

from .logger import JinaLogger
from .. import __windows__
from ..helper import colored, get_readable_size, get_readable_time


Expand Down Expand Up @@ -214,13 +213,15 @@ def __init__(
description: str = 'Working...',
message_on_done: Union[str, Callable[..., str], None] = None,
final_line_feed: bool = True,
total_length: Optional[int] = None,
):
"""
Create the ProgressBar.
:param description: The name of the task, will be displayed in front of the bar.
:param message_on_done: The final message to print when the progress is complete
:param final_line_feed: if False, the line will not get a Line Feed and thus is easily overwritable.
:param total_length: if set, then every :py:meth:`.update` increases the bar by `1/total_length * _bars_on_row`
"""
super().__init__(description, None)
self._bars_on_row = 40
Expand All @@ -229,6 +230,7 @@ def __init__(
self._num_update_called = 0
self._on_done = message_on_done
self._final_line_feed = final_line_feed
self._total_length = total_length
self._stop_event = threading.Event()

def update(
Expand All @@ -248,7 +250,8 @@ def update(
:param status: If set to a value, it will mark the task as complete, can be either "Done" or "Canceled"
:param first_enter: if this method is called by `__enter__`
"""
self._num_update_called += 0 if first_enter else 1
if self._total_length:
progress = progress / self._total_length * self._bars_on_row
self._completed_progress += progress
self._last_rendered_progress = self._completed_progress
elapsed = time.perf_counter() - self.start
Expand All @@ -268,11 +271,15 @@ def update(
else:
bar_color, unfinished_bar_color = 'green', 'green'

speed_str = (
'estimating...'
if first_enter
else f'{self._num_update_called / elapsed:4.1f} step/s'
)
if first_enter:
speed_str = 'estimating...'
elif self._total_length:
_prog = self._num_update_called / self._total_length
speed_str = f'{(_prog * 100):.0f}% ETA: {get_readable_time(seconds=self.now() / (_prog + 1e-6) * (1 - _prog + 1e-6))}'
else:
speed_str = f'{self._num_update_called / elapsed:4.1f} step/s'

self._num_update_called += 0 if first_enter else 1

description_str = description or self.task_name or ''
if status != ProgressBarStatus.WORKING:
Expand Down
4 changes: 3 additions & 1 deletion tests/unit/test_progress_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
@pytest.mark.parametrize('details', [None, 'step {}'])
@pytest.mark.parametrize('msg_on_done', [None, 'done!', lambda: 'done!'])
def test_progressbar(total_steps, update_tick, task_name, capsys, details, msg_on_done):
with ProgressBar(description=task_name, message_on_done=msg_on_done) as pb:
with ProgressBar(
description=task_name, message_on_done=msg_on_done, total_length=total_steps
) as pb:
for j in range(total_steps):
pb.update(update_tick, message=details.format(j) if details else None)
time.sleep(0.001)
Expand Down

0 comments on commit c098951

Please sign in to comment.