diff --git a/deepchem/models/keras_model.py b/deepchem/models/keras_model.py
index 2a8d92951c..5a9bf3f074 100644
--- a/deepchem/models/keras_model.py
+++ b/deepchem/models/keras_model.py
@@ -273,7 +273,7 @@ def fit(self,
           variables: Optional[List[tf.Variable]] = None,
           loss: Optional[KerasLossFn] = None,
           callbacks: Union[Callable, List[Callable]] = [],
-          return_loss_curve: bool = False) -> Union[float, List[float]]:
+          all_losses: Optional[list] = None) -> float:
     """Train this model on a dataset.
 
     Parameters
@@ -303,31 +303,30 @@ def fit(self,
     callbacks: function or list of functions
       one or more functions of the form f(model, step) that will be invoked
       after every step. This can be used to perform validation, logging, etc.
-    return_loss_curve: bool, optional (default False)
-      If `True` return the full set of average losses computed over the
-      process of fitting. Else return the last computed average loss.
+    all_losses: list, optional (default None)
+      If specified, all logged losses are appended into this list. Note that
+      you can call `fit()` repeatedly with the same list and losses will
+      continue to be appended.
 
     Returns
     -------
-    Either the average loss over the most recent checkpoint interval or a list
-    of all such average losses over the course of fitting.
+    The average loss over the most recent checkpoint interval
     """
     return self.fit_generator(
         self.default_generator(
-            dataset, epochs=nb_epoch, deterministic=deterministic),
-        max_checkpoints_to_keep, checkpoint_interval, restore, variables, loss,
-        callbacks, return_loss_curve)
-
-  def fit_generator(
-      self,
-      generator: Iterable[Tuple[Any, Any, Any]],
-      max_checkpoints_to_keep: int = 5,
-      checkpoint_interval: int = 1000,
-      restore: bool = False,
-      variables: Optional[List[tf.Variable]] = None,
-      loss: Optional[KerasLossFn] = None,
-      callbacks: Union[Callable, List[Callable]] = [],
-      return_loss_curve: bool = False) -> Union[float, List[float]]:
+            dataset, epochs=nb_epoch,
+            deterministic=deterministic), max_checkpoints_to_keep,
+        checkpoint_interval, restore, variables, loss, callbacks, all_losses)
+
+  def fit_generator(self,
+                    generator: Iterable[Tuple[Any, Any, Any]],
+                    max_checkpoints_to_keep: int = 5,
+                    checkpoint_interval: int = 1000,
+                    restore: bool = False,
+                    variables: Optional[List[tf.Variable]] = None,
+                    loss: Optional[KerasLossFn] = None,
+                    callbacks: Union[Callable, List[Callable]] = [],
+                    all_losses: Optional[list] = None) -> float:
     """Train this model on data from a generator.
 
     Parameters
@@ -353,14 +352,14 @@ def fit_generator(
     callbacks: function or list of functions
       one or more functions of the form f(model, step) that will be invoked
       after every step. This can be used to perform validation, logging, etc.
-    return_loss_curve: bool, optional (default False)
-      If `True` return the full set of average losses computed over the
-      process of fitting. Else return the last computed average loss.
+    all_losses: list, optional (default None)
+      If specified, all logged losses are appended into this list. Note that
+      you can call `fit()` repeatedly with the same list and losses will
+      continue to be appended.
 
     Returns
     -------
-    Either the average loss over the most recent checkpoint interval or a list
-    of all such average losses over the course of fitting.
+    The average loss over the most recent checkpoint interval
     """
     if not isinstance(callbacks, SequenceCollection):
       callbacks = [callbacks]
@@ -441,13 +440,12 @@ def fit_generator(
     time2 = time.time()
     logger.info("TIMING: model fitting took %0.3f s" % (time2 - time1))
-    if return_loss_curve:
-      return avg_losses
+    if all_losses is not None:
+      all_losses.extend(avg_losses)
+    if len(avg_losses) > 0:
+      return avg_losses[-1]
     else:
-      if len(avg_losses) > 0:
-        return avg_losses[-1]
-      else:
-        return 0.0
+      return 0.0
 
   def _create_gradient_fn(self,
                           variables: Optional[List[tf.Variable]]) -> Callable:
@@ -516,17 +514,14 @@ def fit_on_batch(self,
     """
     self._ensure_built()
     dataset = NumpyDataset(X, y, w)
-    # We set return_loss_curve=False, so we know this is a float, but mypy
-    # can't automatically infer that.
-    return self.fit(  # type: ignore
+    return self.fit(
         dataset,
         nb_epoch=1,
         max_checkpoints_to_keep=max_checkpoints_to_keep,
         checkpoint_interval=self._global_step.numpy() + 2 if checkpoint else 0,
         variables=variables,
         loss=loss,
-        callbacks=callbacks,
-        return_loss_curve=False)
+        callbacks=callbacks)
 
   def _predict(
       self, generator: Iterable[Tuple[Any, Any, Any]],
diff --git a/deepchem/models/models.py b/deepchem/models/models.py
index 156f92cefe..ae7c0b5763 100644
--- a/deepchem/models/models.py
+++ b/deepchem/models/models.py
@@ -21,7 +21,7 @@
 from deepchem.utils.save import save_to_disk
 from deepchem.utils.evaluate import Evaluator
-from typing import Any, Dict, List, Optional, Sequence, Union
+from typing import Any, Dict, List, Optional, Sequence
 from deepchem.utils.typing import OneOrMany
 
 logger = logging.getLogger(__name__)
@@ -127,8 +127,7 @@ def save(self) -> None:
     """
     raise NotImplementedError
 
-  def fit(self, dataset: Dataset,
-          nb_epoch: int = 10) -> Union[float, List[float]]:
+  def fit(self, dataset: Dataset, nb_epoch: int = 10) -> float:
     """
     Fits a model on data in a Dataset object.
diff --git a/deepchem/models/tests/test_kerasmodel.py b/deepchem/models/tests/test_kerasmodel.py
index 2ecac55e92..5ddaa6caf2 100644
--- a/deepchem/models/tests/test_kerasmodel.py
+++ b/deepchem/models/tests/test_kerasmodel.py
@@ -58,7 +58,7 @@ def test_overfit_sequential_model():
   assert scores[metric.name] > 0.9
 
 
-def test_fit_return_loss_curve():
+def test_fit_use_all_losses():
   """Test fitting a KerasModel and getting a loss curve back."""
   n_data_points = 10
   n_features = 2
@@ -74,7 +74,8 @@
       dc.models.losses.BinaryCrossEntropy(),
       learning_rate=0.005,
       log_frequency=10)
-  losses = model.fit(dataset, nb_epoch=1000, return_loss_curve=True)
+  losses = []
+  model.fit(dataset, nb_epoch=1000, all_losses=losses)
   # Each epoch is a single step for this model
  assert len(losses) == 100
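For reference, a minimal usage sketch of the new `all_losses` argument. The model construction mirrors the test above; the layer sizes and hyperparameters are illustrative, and any `dc.models.KerasModel` behaves the same way:

    import numpy as np
    import tensorflow as tf
    import deepchem as dc

    # Toy binary classification data, as in test_fit_use_all_losses.
    X = np.random.rand(10, 2)
    y = (X[:, 0] > X[:, 1]).astype(np.float64)
    dataset = dc.data.NumpyDataset(X, y)

    keras_model = tf.keras.Sequential([
        tf.keras.layers.Dense(10, activation='relu'),
        tf.keras.layers.Dense(1, activation='sigmoid')
    ])
    model = dc.models.KerasModel(
        keras_model,
        dc.models.losses.BinaryCrossEntropy(),
        learning_rate=0.005,
        log_frequency=10)

    # Pass a list through all_losses; each logged average loss is appended.
    losses = []
    model.fit(dataset, nb_epoch=100, all_losses=losses)
    # A second call with the same list keeps appending, so `losses` accumulates
    # the loss curve across both calls, while fit() itself returns only the
    # most recent average loss as a float.
    model.fit(dataset, nb_epoch=100, all_losses=losses)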