Skip to content

Commit

Permalink
Merge pull request #2333 from casperdcl/tqdm
Browse files Browse the repository at this point in the history
change progress bar backend to tqdm
  • Loading branch information
efiop committed Aug 21, 2019
2 parents 4d267ef + 4bd2ce2 commit 5aa9a2f
Show file tree
Hide file tree
Showing 21 changed files with 369 additions and 562 deletions.
1 change: 1 addition & 0 deletions .gitignore
Expand Up @@ -23,6 +23,7 @@ innosetup/config.ini
*.exe

.coverage
.coverage.*

*.swp

Expand Down
6 changes: 6 additions & 0 deletions .mailmap
@@ -0,0 +1,6 @@
Paweł Redzyński <pawelredzynski@gmail.com>
Dmitry Petrov <dmitry.petrov@nevesomo.com>
Earl Hathaway <github@earlh.com>
Nabanita Dash <dashnabanita@gmail.com>
Kurian Benoy <kurian.bkk@gmail.com>
Sritanu Chakraborty <sritanu25@gmail.com>
25 changes: 14 additions & 11 deletions dvc/logger.py
Expand Up @@ -3,6 +3,7 @@
from __future__ import unicode_literals

from dvc.utils.compat import str, StringIO
from dvc.progress import Tqdm

import logging
import logging.handlers
Expand Down Expand Up @@ -53,9 +54,6 @@ class ColorFormatter(logging.Formatter):
)

def format(self, record):
if self._is_visible(record):
self._progress_aware()

if record.levelname == "INFO":
return record.msg

Expand Down Expand Up @@ -146,20 +144,25 @@ def _parse_exc(self, exc_info):

return (exception, stack_trace)

def _progress_aware(self):
"""Add a new line if progress bar hasn't finished"""
from dvc.progress import progress

if not progress.is_finished:
progress._print()
progress.clearln()


class LoggerHandler(logging.StreamHandler):
def handleError(self, record):
super(LoggerHandler, self).handleError(record)
raise LoggingException(record)

def emit(self, record):
"""Write to Tqdm's stream so as to not break progressbars"""
try:
msg = self.format(record)
Tqdm.write(
msg, file=self.stream, end=getattr(self, "terminator", "\n")
)
self.flush()
except RecursionError:
raise
except Exception:
self.handleError(record)


def setup(level=logging.INFO):
colorama.init()
Expand Down
8 changes: 3 additions & 5 deletions dvc/output/base.py
Expand Up @@ -277,6 +277,7 @@ def download(self, to):

def checkout(self, force=False, progress_callback=None, tag=None):
if not self.use_cache:
progress_callback(str(self.path_info), self.get_files_number())
return

if tag:
Expand Down Expand Up @@ -313,13 +314,10 @@ def move(self, out):
self.repo.scm.ignore(self.fspath)

def get_files_number(self):
if not self.use_cache or not self.checksum:
if not self.use_cache:
return 0

if self.is_dir_checksum:
return len(self.dir_cache)

return 1
return self.cache.get_files_number(self.checksum)

def unprotect(self):
if self.exists:
Expand Down
222 changes: 81 additions & 141 deletions dvc/progress.py
@@ -1,154 +1,94 @@
"""Manages progress bars for dvc repo."""

from __future__ import print_function
from __future__ import unicode_literals

from dvc.utils.compat import str

import sys
import threading
import logging
from tqdm import tqdm
from copy import deepcopy
from concurrent.futures import ThreadPoolExecutor

CLEARLINE_PATTERN = "\r\x1b[K"


class Progress(object):
class TqdmThreadPoolExecutor(ThreadPoolExecutor):
"""
Simple multi-target progress bar.
Ensure worker progressbars are cleared away properly.
"""

def __init__(self):
self._n_total = 0
self._n_finished = 0
self._lock = threading.Lock()
self._line = None

def set_n_total(self, total):
"""Sets total number of targets."""
self._n_total = total
self._n_finished = 0

@property
def is_finished(self):
"""Returns if all targets have finished."""
return self._n_total == self._n_finished

def clearln(self):
self._print(CLEARLINE_PATTERN, end="")

def _writeln(self, line):
self.clearln()
self._print(line, end="")
sys.stdout.flush()

def reset(self):
with self._lock:
self._n_total = 0
self._n_finished = 0
self._line = None

def refresh(self, line=None):
"""Refreshes progress bar."""
# Just go away if it is locked. Will update next time
if not self._lock.acquire(False):
return

if line is None:
line = self._line

if sys.stdout.isatty() and line is not None:
self._writeln(line)
self._line = line

self._lock.release()

def update_target(self, name, current, total):
"""Updates progress bar for a specified target."""
self.refresh(self._bar(name, current, total))

def finish_target(self, name):
"""Finishes progress bar for a specified target."""
# We have to write a msg about finished target
with self._lock:
pbar = self._bar(name, 100, 100)

if sys.stdout.isatty():
self.clearln()

self._print(pbar)

self._n_finished += 1
self._line = None

def _bar(self, target_name, current, total):
def __enter__(self):
"""
Make a progress bar out of info, which looks like:
(1/2): [########################################] 100% master.zip
Creates a blank initial dummy progress bar if needed so that workers
are forced to create "nested" bars.
"""
bar_len = 30

if total is None:
state = 0
percent = "?% "
else:
total = int(total)
state = int((100 * current) / total) if current < total else 100
percent = str(state) + "% "

if self._n_total > 1:
num = "({}/{}): ".format(self._n_finished + 1, self._n_total)
else:
num = ""
blank_bar = Tqdm(bar_format="Multi-Threaded:", leave=False)
if blank_bar.pos > 0:
# already nested - don't need a placeholder bar
blank_bar.close()
self.bar = blank_bar
return super(TqdmThreadPoolExecutor, self).__enter__()

n_sh = int((state * bar_len) / 100)
n_sp = bar_len - n_sh
pbar = "[" + "#" * n_sh + " " * n_sp + "] "
def __exit__(self, *a, **k):
super(TqdmThreadPoolExecutor, self).__exit__(*a, **k)
self.bar.close()

return num + pbar + percent + target_name

@staticmethod
def _print(*args, **kwargs):
import logging

logger = logging.getLogger(__name__)

if logger.getEffectiveLevel() == logging.CRITICAL:
return

print(*args, **kwargs)

def __enter__(self):
self._lock.acquire(True)
if self._line is not None:
self.clearln()

def __exit__(self, typ, value, tbck):
if self._line is not None:
self.refresh()
self._lock.release()

def __call__(self, seq, name="", total=None):
if total is None:
total = len(seq)

self.update_target(name, 0, total)
for done, item in enumerate(seq, start=1):
yield item
self.update_target(name, done, total)
self.finish_target(name)


class ProgressCallback(object):
def __init__(self, total):
self.total = total
self.current = 0
progress.reset()

def update(self, name, progress_to_add=1):
self.current += progress_to_add
progress.update_target(name, self.current, self.total)

def finish(self, name):
progress.finish_target(name)

class Tqdm(tqdm):
"""
maximum-compatibility tqdm-based progressbars
"""

progress = Progress() # pylint: disable=invalid-name
def __init__(
self,
iterable=None,
disable=None,
bytes=False, # pylint: disable=W0622
desc_truncate=None,
leave=None,
**kwargs
):
"""
bytes : shortcut for
`unit='B', unit_scale=True, unit_divisor=1024, miniters=1`
desc_truncate : like `desc` but will truncate to 10 chars
kwargs : anything accepted by `tqdm.tqdm()`
"""
kwargs = deepcopy(kwargs)
if bytes:
for k, v in dict(
unit="B", unit_scale=True, unit_divisor=1024, miniters=1
).items():
kwargs.setdefault(k, v)
if desc_truncate is not None:
kwargs.setdefault("desc", self.truncate(desc_truncate))
if disable is None:
disable = (
logging.getLogger(__name__).getEffectiveLevel()
>= logging.CRITICAL
)
super(Tqdm, self).__init__(
iterable=iterable, disable=disable, leave=leave, **kwargs
)

def update_desc(self, desc, n=1, truncate=True):
"""
Calls `set_description(truncate(desc))` and `update(n)`
"""
self.set_description(
self.truncate(desc) if truncate else desc, refresh=False
)
self.update(n)

def update_to(self, current, total=None):
if total:
self.total = total # pylint: disable=W0613,W0201
self.update(current - self.n)

@classmethod
def truncate(cls, s, max_len=25, end=True, fill="..."):
"""
Guarantee len(output) < max_lenself.
>>> truncate("hello", 4)
'...o'
"""
if len(s) <= max_len:
return s
if len(fill) > max_len:
return fill[-max_len:] if end else fill[:max_len]
i = max_len - len(fill)
return (fill + s[-i:]) if end else (s[:i] + fill)
36 changes: 19 additions & 17 deletions dvc/remote/azure.py
Expand Up @@ -17,7 +17,7 @@
BlockBlobService = None

from dvc.utils.compat import urlparse
from dvc.progress import progress
from dvc.progress import Tqdm
from dvc.config import Config
from dvc.remote.base import RemoteBASE
from dvc.path_info import CloudURLInfo
Expand All @@ -26,14 +26,6 @@
logger = logging.getLogger(__name__)


class Callback(object):
def __init__(self, name):
self.name = name

def __call__(self, current, total):
progress.update_target(self.name, current, total)


class RemoteAZURE(RemoteBASE):
scheme = Schemes.AZURE
path_cls = CloudURLInfo
Expand Down Expand Up @@ -123,18 +115,28 @@ def list_cache_paths(self):
def _upload(
self, from_file, to_info, name=None, no_progress_bar=False, **_kwargs
):
cb = None if no_progress_bar else Callback(name)
self.blob_service.create_blob_from_path(
to_info.bucket, to_info.path, from_file, progress_callback=cb
)
with Tqdm(
desc_truncate=name, disable=no_progress_bar, bytes=True
) as pbar:
self.blob_service.create_blob_from_path(
to_info.bucket,
to_info.path,
from_file,
progress_callback=pbar.update_to,
)

def _download(
self, from_info, to_file, name=None, no_progress_bar=False, **_kwargs
):
cb = None if no_progress_bar else Callback(name)
self.blob_service.get_blob_to_path(
from_info.bucket, from_info.path, to_file, progress_callback=cb
)
with Tqdm(
desc_truncate=name, disable=no_progress_bar, bytes=True
) as pbar:
self.blob_service.get_blob_to_path(
from_info.bucket,
from_info.path,
to_file,
progress_callback=pbar.update_to,
)

def exists(self, path_info):
paths = self._list_paths(path_info.bucket, path_info.path)
Expand Down

0 comments on commit 5aa9a2f

Please sign in to comment.