Skip to content

Commit

Permalink
fixed printed string readability. added ETA
Browse files Browse the repository at this point in the history
  • Loading branch information
justanhduc committed Apr 7, 2022
1 parent b4273c0 commit 19f3b14
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 16 deletions.
47 changes: 32 additions & 15 deletions neural_monitor/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,20 @@ def standardize_image(img):
return img


def _convert_time_human_readable(t):
if t < 3600:
time_unit = 'mins'
t /= 60.
elif 86400 > t >= 3600:
time_unit = 'hrs'
t /= 3600.
else:
time_unit = 'days'
t /= 86400

return t, time_unit


class Monitor:
"""
Collects statistics and displays the results using various backends.
Expand Down Expand Up @@ -319,6 +333,7 @@ def __init__(self):
self._num_iters = None
self.print_freq = 1
self.num_iters = None
self.num_epochs = None
self.use_tensorboard = None
self.current_folder = None
self.plot_folder = None
Expand Down Expand Up @@ -371,7 +386,7 @@ def __setattr__(self, attr, val):

def initialize(self, model_name: Optional[str] = None, root: Optional[str] = None,
current_folder: Optional[str] = None, print_freq: Optional[int] = 1,
num_iters: Optional[int] = None, prefix: Optional[str] = None,
num_iters: Optional[int] = None, num_epochs: Optional[int] = None, prefix: Optional[str] = None,
use_tensorboard: Optional[bool] = True, with_git: Optional[bool] = False,
not_found_warn: bool = True) -> None:
"""
Expand All @@ -398,6 +413,10 @@ def initialize(self, model_name: Optional[str] = None, root: Optional[str] = Non
number of iterations per epoch.
If not provided, it will be calculated after one epoch.
Default: ``None``.
:param num_epochs:
total number of epochs.
If provided, ETA will be shown.
Default: ``None``.
:param prefix:
a common prefix that is shared between folder names of different runs.
Default: ``'run'``.
Expand All @@ -424,6 +443,7 @@ def initialize(self, model_name: Optional[str] = None, root: Optional[str] = Non
self._num_iters = num_iters
self.print_freq = print_freq
self.num_iters = num_iters
self.num_epochs = num_epochs
self.use_tensorboard = use_tensorboard
self.current_folder = os.path.abspath(current_folder) if current_folder is not None else None
self.with_git = with_git
Expand Down Expand Up @@ -1447,24 +1467,21 @@ def _flush(self):
f.close()
lock.release_write()

iter_show = 'Epoch {} Iteration {}/{} ({:.2f}%)'.format(
epoch + 1, it % self.num_iters, self.num_iters,
(it % self.num_iters) / self.num_iters * 100.) if self.num_iters \
else 'Epoch {} Iteration {}'.format(epoch + 1, it)
it_percentage = (it % self.num_iters) / self.num_iters if self.num_iters else None
iter_show = f'Epoch {epoch + 1} Iteration {it % self.num_iters}/{self.num_iters} ' \
f'({it_percentage * 100:.2f}%)' if self.num_iters else f'Epoch {epoch + 1} Iteration {it}'

elapsed_time = time.time() - self._timer
if elapsed_time < 3600:
time_unit = 'mins'
elapsed_time /= 60.
elif 86400 > elapsed_time >= 3600:
time_unit = 'hrs'
elapsed_time /= 3600.
if self.num_iters and self.num_epochs:
eta = elapsed_time / (epoch + it_percentage + 1e-8) * (self.num_epochs - (epoch + it_percentage))
eta, eta_unit = _convert_time_human_readable(eta)
eta_str = f'ETA {eta:.2f}{eta_unit}'
else:
time_unit = 'days'
elapsed_time /= 86400
eta_str = f'ETA N/A'

elapsed_time_str = '{:.2f}'.format(elapsed_time) + time_unit
log = 'Elapsed time {} {}\t{}\t{}'.format(elapsed_time_str, self.current_run, iter_show, '\t'.join(prints))
elapsed_time, elapsed_time_unit = _convert_time_human_readable(elapsed_time)
elapsed_time_str = f'{elapsed_time:.2f}{elapsed_time_unit}'
log = f'{self.current_run}\t Elapsed time {elapsed_time_str} ({eta_str})\t{iter_show}\t' + '\t'.join(prints)
root_logger.info(log)
self._q.task_done()

Expand Down
2 changes: 1 addition & 1 deletion neural_monitor/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import threading
import numpy as np

log_formatter = logging.Formatter('%(asctime)s [%(threadName)-12.12s] [%(levelname)-5.5s] %(message)s')
log_formatter = logging.Formatter('%(asctime)s [%(threadName)-12.12s] [%(levelname)-5.5s] %(message)s\n')
root_logger = logging.getLogger()
root_logger.setLevel(logging.INFO)

Expand Down

0 comments on commit 19f3b14

Please sign in to comment.