/
progress_bar.py
128 lines (103 loc) · 4.3 KB
/
progress_bar.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
from __future__ import division
import datetime
import os
import sys
import time
from chainer.training import extension
from chainer.training.extensions import util
class ProgressBar(extension.Extension):

    """Trainer extension to print a progress bar and recent training status.

    This extension prints a progress bar at every call. It watches the current
    iteration and epoch to print the bar.

    Args:
        training_length (tuple): Length of whole training. It consists of an
            integer and either ``'epoch'`` or ``'iteration'``. If this value is
            omitted, the length is obtained once from the trainer's stop
            trigger via its ``get_training_length()`` method (e.g. an
            :class:`IntervalTrigger`) and cached.
        update_interval (int): Number of iterations to skip printing the
            progress bar.
        bar_length (int): Length of the progress bar in characters.
        out: Stream to print the bar. Standard output is used by default.

    """

    def __init__(self, training_length=None, update_interval=100,
                 bar_length=50, out=sys.stdout):
        self._training_length = training_length
        self._status_template = None
        self._update_interval = update_interval
        self._bar_length = bar_length
        self._out = out
        # (iteration, epoch_detail, time.time()) samples, one per printed
        # update. The oldest entry is the baseline for the speed estimate;
        # the list is trimmed back to 100 entries after each update.
        self._recent_timing = []

    def __call__(self, trainer):
        """Redraw the bars, status line and ETA every ``update_interval`` iters.

        After drawing, the cursor is moved back to the first line of the bar
        so that the next update overwrites it in place.
        """
        training_length = self._training_length
        if training_length is None:
            # Resolve the length once; the status template below is cached
            # after the first call as well, so both stay consistent.
            training_length = trainer.stop_trigger.get_training_length()
            self._training_length = training_length
        length, unit = training_length

        if self._status_template is None:
            # An explicit tuple keeps %-formatting working even when the
            # user passed training_length as a list.
            self._status_template = (
                '{0.iteration:10} iter, {0.epoch} epoch / %s %ss\n' %
                (length, unit))

        updater = trainer.updater
        iteration = updater.iteration
        if iteration % self._update_interval != 0:
            return

        epoch = updater.epoch_detail
        now = time.time()
        self._recent_timing.append((iteration, epoch, now))

        out = self._out
        self._erase_console()

        # Overall progress, measured in the unit of the training length.
        if unit == 'iteration':
            rate = iteration / length
        else:
            rate = epoch / length
        self._write_bar(' total [{}{}] {:6.2%}\n', min(rate, 1.0))
        # Fractional part of epoch_detail = progress within the current epoch.
        self._write_bar('this epoch [{}{}] {:6.2%}\n', epoch - int(epoch))

        out.write(self._status_template.format(updater))

        speed_t, speed_e = self._compute_speed(iteration, epoch, now)
        if unit == 'iteration':
            remaining, speed = length - iteration, speed_t
        else:
            remaining, speed = length - epoch, speed_e
        if speed == 0:
            # No progress since the oldest sample although time has passed:
            # the ETA is undefined and dividing by the speed would raise.
            eta = 'unknown'
        else:
            eta = datetime.timedelta(seconds=max(remaining / speed, 0.0))
        out.write('{:10.5g} iters/sec. Estimated time to finish: {}.\n'
                  .format(speed_t, eta))

        self._move_cursor_to_bar_head()
        self._flush()

        if len(self._recent_timing) > 100:
            del self._recent_timing[0]

    def finalize(self):
        """Erase the progress bar left on screen by the last update."""
        # __call__ leaves the cursor at the head of the bar, so erasing from
        # the cursor onward removes it.
        self._erase_console()
        self._flush()

    def _write_bar(self, template, rate):
        """Write one bar line; ``template`` receives marks, padding and rate."""
        bar_length = self._bar_length
        marks = '#' * int(rate * bar_length)
        self._out.write(
            template.format(marks, '.' * (bar_length - len(marks)), rate))

    def _compute_speed(self, iteration, epoch, now):
        """Return (iterations/sec, epochs/sec) since the oldest timing sample.

        Both are ``inf`` when no wall time has elapsed since that sample
        (e.g. on the very first update).
        """
        old_t, old_e, old_sec = self._recent_timing[0]
        span = now - old_sec
        if span == 0:
            return float('inf'), float('inf')
        return (iteration - old_t) / span, (epoch - old_e) / span

    def _erase_console(self):
        """Erase the console from the cursor position onward."""
        if os.name == 'nt':
            # NOTE(review): semantics of util.erase_console's arguments are
            # defined outside this file; assumed to mirror ESC[J below.
            util.erase_console(0, 0)
        else:
            # ANSI "erase in display" (cursor to end of screen).
            self._out.write('\033[J')

    def _move_cursor_to_bar_head(self):
        """Move the cursor up over the four lines written by ``__call__``."""
        # Four lines: total bar, epoch bar, status line, speed/ETA line.
        if os.name == 'nt':
            util.set_console_cursor_position(0, -4)
        else:
            # ANSI "cursor up" by 4 lines.
            self._out.write('\033[4A')

    def _flush(self):
        """Flush the output stream if it supports flushing."""
        if hasattr(self._out, 'flush'):
            self._out.flush()