Skip to content

Commit

Permalink
Add print to Loop. #4.
Browse files Browse the repository at this point in the history
Add attention mechanism.
  • Loading branch information
lizeyan committed Sep 7, 2018
1 parent 67bc059 commit b1b9ef9
Showing 1 changed file with 14 additions and 11 deletions.
25 changes: 14 additions & 11 deletions snippets/scaffold/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,19 @@ def submit_data(self, name, value):
else:
raise RuntimeError("Can't submit data outside epoch or step")

def print(self, string):
if self._max_epochs is None:
epoch_str = "{}".format(self._epoch_cnt)
else:
epoch_str = "{}/{}".format(self._epoch_cnt, self._max_epochs)
if self._max_steps is None:
step_str = "{}".format(self._step_cnt)
else:
step_str = "{}/{}".format(self._step_cnt, self._max_steps)
process_str = "[epoch:{} step:{} ETA:{:.3f}s]".format(epoch_str, step_str,
self._eta())
self._print_fn("{} {}".format(process_str, string))

def _print_log(self, unit: str):
if self._print_fn is None:
return
Expand All @@ -212,17 +225,7 @@ def _print_log(self, unit: str):
metric_str_list.append(metric.format(item))
metric_str = " ".join(metric_str_list)

if self._max_epochs is None:
epoch_str = "{}".format(self._epoch_cnt)
else:
epoch_str = "{}/{}".format(self._epoch_cnt, self._max_epochs)
if self._max_steps is None:
step_str = "{}".format(self._step_cnt)
else:
step_str = "{}/{}".format(self._step_cnt, self._max_steps)
process_str = "[epoch:{} step:{} ETA:{:.3f}s]".format(epoch_str, step_str,
self._eta())
self._print_fn("{} {}".format(process_str, metric_str))
self.print(metric_str)


TrainLoop = Loop
Expand Down

0 comments on commit b1b9ef9

Please sign in to comment.