Swapping to all_losses
Bharath Ramsundar authored and Bharath Ramsundar committed Jul 29, 2020
1 parent b6beb3b commit 744fede
Showing 3 changed files with 36 additions and 41 deletions.
67 changes: 31 additions & 36 deletions deepchem/models/keras_model.py
@@ -273,7 +273,7 @@ def fit(self,
variables: Optional[List[tf.Variable]] = None,
loss: Optional[KerasLossFn] = None,
callbacks: Union[Callable, List[Callable]] = [],
return_loss_curve: bool = False) -> Union[float, List[float]]:
all_losses: Optional[list] = None) -> float:
"""Train this model on a dataset.
Parameters
@@ -303,31 +303,30 @@ def fit(self,
callbacks: function or list of functions
one or more functions of the form f(model, step) that will be invoked after
every step. This can be used to perform validation, logging, etc.
return_loss_curve: bool, optional (default False)
If `True` return the full set of average losses computed over the
process of fitting. Else return the last computed average loss.
all_losses: list, optional (default None)
If specified, all logged losses are appended into this list. Note that
you can call `fit()` repeatedly with the same list and losses will
continue to be appended.
Returns
-------
Either the average loss over the most recent checkpoint interval or a list
of all such average losses over the course of fitting.
The average loss over the most recent checkpoint interval
"""
return self.fit_generator(
self.default_generator(
dataset, epochs=nb_epoch, deterministic=deterministic),
max_checkpoints_to_keep, checkpoint_interval, restore, variables, loss,
callbacks, return_loss_curve)

def fit_generator(
self,
generator: Iterable[Tuple[Any, Any, Any]],
max_checkpoints_to_keep: int = 5,
checkpoint_interval: int = 1000,
restore: bool = False,
variables: Optional[List[tf.Variable]] = None,
loss: Optional[KerasLossFn] = None,
callbacks: Union[Callable, List[Callable]] = [],
return_loss_curve: bool = False) -> Union[float, List[float]]:
dataset, epochs=nb_epoch,
deterministic=deterministic), max_checkpoints_to_keep,
checkpoint_interval, restore, variables, loss, callbacks, all_losses)
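
For illustration only (not part of this diff), a minimal usage sketch of the new all_losses argument, assuming `model` is an already-built KerasModel and `dataset` is a dc.data.Dataset:

    losses = []
    model.fit(dataset, nb_epoch=10, all_losses=losses)
    # losses now holds the average losses logged during this call
    model.fit(dataset, nb_epoch=10, all_losses=losses)
    # a second call with the same list keeps appending, extending the loss curve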

def fit_generator(self,
generator: Iterable[Tuple[Any, Any, Any]],
max_checkpoints_to_keep: int = 5,
checkpoint_interval: int = 1000,
restore: bool = False,
variables: Optional[List[tf.Variable]] = None,
loss: Optional[KerasLossFn] = None,
callbacks: Union[Callable, List[Callable]] = [],
all_losses: Optional[list] = None) -> float:
"""Train this model on data from a generator.
Parameters
@@ -353,14 +352,14 @@ def fit_generator(
callbacks: function or list of functions
one or more functions of the form f(model, step) that will be invoked after
every step. This can be used to perform validation, logging, etc.
return_loss_curve: bool, optional (default False)
If `True` return the full set of average losses computed over the
process of fitting. Else return the last computed average loss.
all_losses: list, optional (default None)
If specified, all logged losses are appended into this list. Note that
you can call `fit()` repeatedly with the same list and losses will
continue to be appended.
Returns
-------
Either the average loss over the most recent checkpoint interval or a list
of all such average losses over the course of fitting.
The average loss over the most recent checkpoint interval
"""
if not isinstance(callbacks, SequenceCollection):
callbacks = [callbacks]
@@ -441,13 +440,12 @@ def fit_generator(

time2 = time.time()
logger.info("TIMING: model fitting took %0.3f s" % (time2 - time1))
if return_loss_curve:
return avg_losses
if all_losses is not None:
all_losses.extend(avg_losses)
if len(avg_losses) > 0:
return avg_losses[-1]
else:
if len(avg_losses) > 0:
return avg_losses[-1]
else:
return 0.0
return 0.0
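
As a side note (not from the commit), the same list can be passed straight to fit_generator(), which fit() simply forwards together with the output of default_generator(). A hedged sketch, assuming the same `model` and `dataset` as above:

    losses = []
    model.fit_generator(
        model.default_generator(dataset, epochs=5),
        all_losses=losses)
    # losses is filled in place; the return value is still the last average loss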

def _create_gradient_fn(self,
variables: Optional[List[tf.Variable]]) -> Callable:
@@ -516,17 +514,14 @@ def fit_on_batch(self,
"""
self._ensure_built()
dataset = NumpyDataset(X, y, w)
# We set return_loss_curve=False, so we know this is a float, but mypy
# can't automatically infer that.
return self.fit( # type: ignore
return self.fit(
dataset,
nb_epoch=1,
max_checkpoints_to_keep=max_checkpoints_to_keep,
checkpoint_interval=self._global_step.numpy() + 2 if checkpoint else 0,
variables=variables,
loss=loss,
callbacks=callbacks,
return_loss_curve=False)
callbacks=callbacks)
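
For context (not part of the diff): since fit() now always returns a float, fit_on_batch() can return its result directly without the previous type: ignore. An illustrative call, assuming X, y, and w are NumPy arrays of matching length:

    loss = model.fit_on_batch(X, y, w)
    # loss is a plain float, the average loss over the most recent checkpoint interval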

def _predict(
self, generator: Iterable[Tuple[Any, Any, Any]],
5 changes: 2 additions & 3 deletions deepchem/models/models.py
@@ -21,7 +21,7 @@
from deepchem.utils.save import save_to_disk
from deepchem.utils.evaluate import Evaluator

from typing import Any, Dict, List, Optional, Sequence, Union
from typing import Any, Dict, List, Optional, Sequence
from deepchem.utils.typing import OneOrMany

logger = logging.getLogger(__name__)
@@ -127,8 +127,7 @@ def save(self) -> None:
"""
raise NotImplementedError

def fit(self, dataset: Dataset,
nb_epoch: int = 10) -> Union[float, List[float]]:
def fit(self, dataset: Dataset, nb_epoch: int = 10) -> float:
"""
Fits a model on data in a Dataset object.
5 changes: 3 additions & 2 deletions deepchem/models/tests/test_kerasmodel.py
@@ -58,7 +58,7 @@ def test_overfit_sequential_model():
assert scores[metric.name] > 0.9


def test_fit_return_loss_curve():
def test_fit_use_all_losses():
"""Test fitting a KerasModel and getting a loss curve back."""
n_data_points = 10
n_features = 2
@@ -74,7 +74,8 @@ def test_fit_return_loss_curve():
dc.models.losses.BinaryCrossEntropy(),
learning_rate=0.005,
log_frequency=10)
losses = model.fit(dataset, nb_epoch=1000, return_loss_curve=True)
losses = []
model.fit(dataset, nb_epoch=1000, all_losses=losses)
# Each epoch is a single step for this model
assert len(losses) == 100

