From 63adaf04d2ffff8dec299623627d55d4bacac598 Mon Sep 17 00:00:00 2001
From: Haowen Xu
Date: Wed, 11 Sep 2019 02:47:52 +0800
Subject: [PATCH] add "ensure_variables_initialized" to Trainer

---
 tfsnippet/trainer/base_trainer.py |  8 ++++++--
 tfsnippet/trainer/trainer.py      | 10 ++++++++--
 2 files changed, 14 insertions(+), 4 deletions(-)

diff --git a/tfsnippet/trainer/base_trainer.py b/tfsnippet/trainer/base_trainer.py
index 2e485d80..6ccbee85 100644
--- a/tfsnippet/trainer/base_trainer.py
+++ b/tfsnippet/trainer/base_trainer.py
@@ -79,14 +79,17 @@ class BaseTrainer(object):
         trainer.log_after_steps(1000)  # call `loop.print_logs` every 1000 steps
     """
 
-    def __init__(self, loop):
+    def __init__(self, loop, ensure_variables_initialized=True):
         """
         Initialize the internal states of :class:`BaseTrainer`.
 
         Args:
             loop (TrainLoop): The training loop object.
+            ensure_variables_initialized (bool): Whether or not to ensure
+                the variables are initialized in :meth:`run()`?
         """
         self._loop = loop
+        self._ensure_variables_initialized = ensure_variables_initialized
         self._events = EventSource([
             EventKeys.BEFORE_EXECUTION,
             EventKeys.AFTER_EXECUTION,
@@ -134,7 +137,8 @@ def run(self):
 
             # initialize global training status
             session = get_default_session_or_error()
-            ensure_variables_initialized()
+            if self._ensure_variables_initialized:
+                ensure_variables_initialized()
             self.loop.print_training_summary()
 
             for _ in self.loop.iter_epochs():
diff --git a/tfsnippet/trainer/trainer.py b/tfsnippet/trainer/trainer.py
index 95d50b23..6acbda18 100644
--- a/tfsnippet/trainer/trainer.py
+++ b/tfsnippet/trainer/trainer.py
@@ -65,7 +65,8 @@ class Trainer(BaseTrainer):
     """
 
     def __init__(self, loop, train_op, inputs, data_flow, feed_dict=None,
-                 metrics=None, summaries=None):
+                 metrics=None, summaries=None,
+                 ensure_variables_initialized=True):
         """
 
         Args:
@@ -90,13 +91,18 @@ def __init__(self, loop, train_op, inputs, data_flow, feed_dict=None,
                 of summaries to be run and along with `train_op`, and later
                 to be added to ``loop.summary_writer``.  If
                 ``loop.summary_writer`` is None, then no summary will be run.
+            ensure_variables_initialized (bool): Whether or not to ensure
+                the variables are initialized in :meth:`run()`?
         """
         if loop.max_epoch is None and loop.max_step is None:
             raise ValueError('At least one of `max_epoch`, `max_step` should '
                              'be configured for `loop`.')
         if summaries is not None and is_tensor_object(summaries):
             summaries = [summaries]
-        super(Trainer, self).__init__(loop=loop)
+        super(Trainer, self).__init__(
+            loop=loop,
+            ensure_variables_initialized=ensure_variables_initialized
+        )
 
         # memorize the arguments
         self._inputs = tuple(inputs or ())
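
For reference, a minimal usage sketch of the new flag (not part of the patch). It assumes tfsnippet's top-level exports `TrainLoop`, `Trainer` and `DataFlow` (exact import paths may differ across versions), and the names `input_x`, `w`, `loss`, `train_op` and `x_train` are placeholders standing in for a real model and dataset. Passing `ensure_variables_initialized=False` keeps variable initialization under the caller's control, e.g. when variables are initialized or restored from a checkpoint before `run()` is called, instead of letting `BaseTrainer.run()` invoke `ensure_variables_initialized()` itself.

    import numpy as np
    import tensorflow as tf
    import tfsnippet as spt

    # Placeholder model: one input and a trivial loss, just to wire up the trainer.
    x_train = np.random.normal(size=[1000, 10]).astype(np.float32)
    input_x = tf.placeholder(tf.float32, [None, 10], name='input_x')
    w = tf.get_variable('w', shape=[10, 1], dtype=tf.float32)
    loss = tf.reduce_mean(tf.square(tf.matmul(input_x, w)))
    train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)

    train_flow = spt.DataFlow.arrays([x_train], batch_size=64, shuffle=True)

    with tf.Session().as_default() as session:
        # The caller initializes (or restores) the variables itself ...
        session.run(tf.global_variables_initializer())

        with spt.TrainLoop(tf.trainable_variables(), max_epoch=10) as loop:
            trainer = spt.Trainer(
                loop, train_op, [input_x], train_flow,
                metrics={'loss': loss},
                # ... so the automatic `ensure_variables_initialized()` call
                # in `run()` can be skipped via the new flag.
                ensure_variables_initialized=False,
            )
            trainer.run()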