From 4a51afbfa5000c51cc076f81721fbd575bfb9ae9 Mon Sep 17 00:00:00 2001
From: Bharath Ramsundar
Date: Thu, 30 Jul 2020 12:17:30 -0700
Subject: [PATCH] Improving types and simplifying return

---
 deepchem/models/keras_model.py           | 27 ++++++++++++------------
 deepchem/models/tests/test_kerasmodel.py |  1 +
 2 files changed, 15 insertions(+), 13 deletions(-)

diff --git a/deepchem/models/keras_model.py b/deepchem/models/keras_model.py
index 5a9bf3f074..dc1c3444ab 100644
--- a/deepchem/models/keras_model.py
+++ b/deepchem/models/keras_model.py
@@ -273,7 +273,7 @@ def fit(self,
           variables: Optional[List[tf.Variable]] = None,
           loss: Optional[KerasLossFn] = None,
           callbacks: Union[Callable, List[Callable]] = [],
-          all_losses: Optional[list] = None) -> float:
+          all_losses: Optional[List[float]] = None) -> float:
     """Train this model on a dataset.

     Parameters
@@ -303,7 +303,7 @@ def fit(self,
     callbacks: function or list of functions
       one or more functions of the form f(model, step) that will be invoked after
       every step. This can be used to perform validation, logging, etc.
-    all_losses: list, optional (default False)
+    all_losses: Optional[List[float]], optional (default None)
       If specified, all logged losses are appended into this list. Note that
       you can call `fit()` repeatedly with the same list and losses will
       continue to be appended.
@@ -326,7 +326,7 @@ def fit_generator(self,
                     variables: Optional[List[tf.Variable]] = None,
                     loss: Optional[KerasLossFn] = None,
                     callbacks: Union[Callable, List[Callable]] = [],
-                    all_losses: Optional[list] = None) -> float:
+                    all_losses: Optional[List[float]] = None) -> float:
     """Train this model on data from a generator.

     Parameters
@@ -352,7 +352,7 @@ def fit_generator(self,
     callbacks: function or list of functions
       one or more functions of the form f(model, step) that will be invoked after
       every step. This can be used to perform validation, logging, etc.
-    all_losses: list, optional (default False)
+    all_losses: Optional[List[float]], optional (default None)
       If specified, all logged losses are appended into this list. Note that
       you can call `fit()` repeatedly with the same list and losses will
       continue to be appended.
@@ -367,8 +367,8 @@ def fit_generator(self,
     if checkpoint_interval > 0:
       manager = tf.train.CheckpointManager(self._checkpoint, self.model_dir,
                                            max_checkpoints_to_keep)
-    avg_losses = []
     avg_loss = 0.0
+    last_avg_loss = 0.0
     averaged_batches = 0
     train_op = None
     if loss is None:
@@ -414,7 +414,11 @@ def fit_generator(self,
         avg_loss = float(avg_loss) / averaged_batches
         logger.info(
             'Ending global_step %d: Average loss %g' % (current_step, avg_loss))
-        avg_losses.append(avg_loss)
+        if all_losses is not None:
+          all_losses.append(avg_loss)
+        # Remember the most recent average loss for the return value, since
+        # the accumulator is reset to zero below.
+        last_avg_loss = avg_loss
         avg_loss = 0.0
         averaged_batches = 0

@@ -433,19 +437,16 @@
       avg_loss = float(avg_loss) / averaged_batches
       logger.info(
           'Ending global_step %d: Average loss %g' % (current_step, avg_loss))
-      avg_losses.append(avg_loss)
+      if all_losses is not None:
+        all_losses.append(avg_loss)
+      last_avg_loss = avg_loss

     if checkpoint_interval > 0:
       manager.save()

     time2 = time.time()
     logger.info("TIMING: model fitting took %0.3f s" % (time2 - time1))
-    if all_losses is not None:
-      all_losses.extend(avg_losses)
-    if len(avg_losses) > 0:
-      return avg_losses[-1]
-    else:
-      return 0.0
+    return last_avg_loss

   def _create_gradient_fn(self,
                           variables: Optional[List[tf.Variable]]) -> Callable:
diff --git a/deepchem/models/tests/test_kerasmodel.py b/deepchem/models/tests/test_kerasmodel.py
index 5ddaa6caf2..d84213662e 100644
--- a/deepchem/models/tests/test_kerasmodel.py
+++ b/deepchem/models/tests/test_kerasmodel.py
@@ -78,6 +78,7 @@ def test_fit_use_all_losses():
   model.fit(dataset, nb_epoch=1000, all_losses=losses)
   # Each epoch is a single step for this model
   assert len(losses) == 100
+  assert np.count_nonzero(np.array(losses)) == 100


 def test_fit_on_batch():
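As a usage sketch of the `fit()`/`all_losses` contract this patch pins down, the snippet below mirrors test_fit_use_all_losses; the toy model, random data, and log_frequency here are illustrative assumptions, not values taken from the patch.

import numpy as np
import tensorflow as tf
import deepchem as dc

# Toy binary classifier on random data; with 10 samples and the default
# batch size, each epoch is a single training step.
X = np.random.rand(10, 2)
y = (X[:, 0] > X[:, 1]).astype(np.float32)
dataset = dc.data.NumpyDataset(X, y)
keras_model = tf.keras.Sequential(
    [tf.keras.layers.Dense(1, activation='sigmoid')])
model = dc.models.KerasModel(
    keras_model, dc.models.losses.BinaryCrossEntropy(), log_frequency=10)

losses = []
last_loss = model.fit(dataset, nb_epoch=1000, all_losses=losses)
# 1000 steps logged every 10 steps -> 100 averaged losses appended to the
# caller's list, and fit() now returns the most recent one directly
# instead of building a private list of its own.
assert len(losses) == 100
assert last_loss == losses[-1]

# Repeated calls keep appending to the same list, as the docstring notes.
model.fit(dataset, nb_epoch=1000, all_losses=losses)
assert len(losses) == 200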