Merge pull request #2060 from deepchem/keras_loss
Allow for reporting of loss curve from KerasModel.fit
Bharath Ramsundar committed Jul 31, 2020
Commit 5984e9e (2 parents: b1e6316 + 3f45128)
Showing 3 changed files with 49 additions and 7 deletions.
deepchem/models/keras_model.py (25 additions, 6 deletions)
@@ -272,7 +272,8 @@ def fit(self,
          restore: bool = False,
          variables: Optional[List[tf.Variable]] = None,
          loss: Optional[KerasLossFn] = None,
-          callbacks: Union[Callable, List[Callable]] = []) -> float:
+          callbacks: Union[Callable, List[Callable]] = [],
+          all_losses: Optional[List[float]] = None) -> float:
    """Train this model on a dataset.
    Parameters
@@ -302,16 +303,20 @@
    callbacks: function or list of functions
      one or more functions of the form f(model, step) that will be invoked after
      every step. This can be used to perform validation, logging, etc.
+    all_losses: Optional[List[float]], optional (default None)
+      If specified, all logged losses are appended into this list. Note that
+      you can call `fit()` repeatedly with the same list and losses will
+      continue to be appended.
    Returns
    -------
-    the average loss over the most recent checkpoint interval
+    The average loss over the most recent checkpoint interval
    """
    return self.fit_generator(
        self.default_generator(
            dataset, epochs=nb_epoch,
            deterministic=deterministic), max_checkpoints_to_keep,
-        checkpoint_interval, restore, variables, loss, callbacks)
+        checkpoint_interval, restore, variables, loss, callbacks, all_losses)

  def fit_generator(self,
                    generator: Iterable[Tuple[Any, Any, Any]],
@@ -320,7 +325,8 @@
                    restore: bool = False,
                    variables: Optional[List[tf.Variable]] = None,
                    loss: Optional[KerasLossFn] = None,
-                    callbacks: Union[Callable, List[Callable]] = []) -> float:
+                    callbacks: Union[Callable, List[Callable]] = [],
+                    all_losses: Optional[List[float]] = None) -> float:
    """Train this model on data from a generator.
    Parameters
@@ -346,10 +352,14 @@
    callbacks: function or list of functions
      one or more functions of the form f(model, step) that will be invoked after
      every step. This can be used to perform validation, logging, etc.
+    all_losses: Optional[List[float]], optional (default None)
+      If specified, all logged losses are appended into this list. Note that
+      you can call `fit()` repeatedly with the same list and losses will
+      continue to be appended.
    Returns
    -------
-    the average loss over the most recent checkpoint interval
+    The average loss over the most recent checkpoint interval
    """
    if not isinstance(callbacks, SequenceCollection):
      callbacks = [callbacks]
@@ -358,6 +368,7 @@
    manager = tf.train.CheckpointManager(self._checkpoint, self.model_dir,
                                         max_checkpoints_to_keep)
    avg_loss = 0.0
+    last_avg_loss = 0.0
    averaged_batches = 0
    train_op = None
    if loss is None:
@@ -403,6 +414,11 @@
        avg_loss = float(avg_loss) / averaged_batches
        logger.info(
            'Ending global_step %d: Average loss %g' % (current_step, avg_loss))
+        if all_losses is not None:
+          all_losses.append(avg_loss)
+        # Capture the last avg_loss in case of return since we're resetting to
+        # 0 now
+        last_avg_loss = avg_loss
        avg_loss = 0.0
        averaged_batches = 0

@@ -421,13 +437,16 @@
      avg_loss = float(avg_loss) / averaged_batches
      logger.info(
          'Ending global_step %d: Average loss %g' % (current_step, avg_loss))
+      if all_losses is not None:
+        all_losses.append(avg_loss)
+      last_avg_loss = avg_loss

    if checkpoint_interval > 0:
      manager.save()

    time2 = time.time()
    logger.info("TIMING: model fitting took %0.3f s" % (time2 - time1))
-    return avg_loss
+    return last_avg_loss

  def _create_gradient_fn(self,
                          variables: Optional[List[tf.Variable]]) -> Callable:
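To see what the new argument does end to end, here is a minimal usage sketch; the toy data, the two-layer Sequential model, L2Loss, and log_frequency=2 are illustrative assumptions, not part of this commit:

import deepchem as dc
import numpy as np
import tensorflow as tf

# Toy regression problem, purely for illustration.
X = np.random.rand(100, 5)
y = np.random.rand(100, 1)
dataset = dc.data.NumpyDataset(X, y)
keras_model = tf.keras.Sequential([
    tf.keras.layers.Dense(10, activation='relu'),
    tf.keras.layers.Dense(1)
])
model = dc.models.KerasModel(
    keras_model, dc.models.losses.L2Loss(), log_frequency=2)

losses = []
model.fit(dataset, nb_epoch=10, all_losses=losses)
# Calling fit() again with the same list keeps appending, so the
# recorded loss curve spans both runs.
model.fit(dataset, nb_epoch=10, all_losses=losses)

Note also that the last_avg_loss bookkeeping corrects the return value: avg_loss is reset to zero after every logged interval, so fit_generator() could previously return a freshly zeroed accumulator; it now returns the most recently completed average.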
deepchem/models/models.py (1 addition, 1 deletion)
@@ -140,7 +140,7 @@ def fit(self, dataset: Dataset, nb_epoch: int = 10) -> float:
    Returns
    -------
-    the average loss over the most recent epoch
+    The average loss over the most recent checkpoint interval.
    """
    for epoch in range(nb_epoch):
      logger.info("Starting epoch %s" % str(epoch + 1))
deepchem/models/tests/test_kerasmodel.py (23 additions, 0 deletions)
@@ -58,6 +58,29 @@ def test_overfit_sequential_model():
  assert scores[metric.name] > 0.9


+def test_fit_use_all_losses():
+  """Test fitting a KerasModel and getting a loss curve back."""
+  n_data_points = 10
+  n_features = 2
+  X = np.random.rand(n_data_points, n_features)
+  y = (X[:, 0] > X[:, 1]).astype(np.float32)
+  dataset = dc.data.NumpyDataset(X, y)
+  keras_model = tf.keras.Sequential([
+      tf.keras.layers.Dense(10, activation='relu'),
+      tf.keras.layers.Dense(1, activation='sigmoid')
+  ])
+  model = dc.models.KerasModel(
+      keras_model,
+      dc.models.losses.BinaryCrossEntropy(),
+      learning_rate=0.005,
+      log_frequency=10)
+  losses = []
+  model.fit(dataset, nb_epoch=1000, all_losses=losses)
+  # Each epoch is a single step for this model
+  assert len(losses) == 100
+  assert np.count_nonzero(np.array(losses)) == 100


def test_fit_on_batch():
  """Test fitting a KerasModel to individual batches."""
  n_data_points = 10
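The expected count of 100 in the new test follows from the logging cadence rather than from any randomness: with 10 samples and KerasModel's default batch size of 100, each epoch is one gradient step, and fit() appends one averaged loss every log_frequency (here 10) steps, so 1000 epochs yield 1000 / 10 = 100 entries. A sketch of that arithmetic, using a hypothetical helper that is not part of the test suite:

def expected_num_losses(n_epochs, steps_per_epoch, log_frequency):
  """Averaged losses appended to all_losses by fit() (hypothetical helper)."""
  total_steps = n_epochs * steps_per_epoch
  full_intervals, remainder = divmod(total_steps, log_frequency)
  # A trailing partial interval would add one more averaged entry;
  # here 1000 steps divide evenly, so none is appended.
  return full_intervals + (1 if remainder else 0)

assert expected_num_losses(1000, 1, 10) == 100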
