Commit 4a51afb
Improving types and simplifying return
Bharath Ramsundar authored and Bharath Ramsundar committed Jul 30, 2020
1 parent 8d7f70a commit 4a51afb
Showing 2 changed files with 15 additions and 13 deletions.
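In brief: this commit tightens the type hint of the `all_losses` argument to `fit()` and `fit_generator()` from `Optional[list]` to `Optional[List[float]]`, appends each logged average directly to the caller's list, and simplifies the return value to a single `last_avg_loss` scalar. A minimal usage sketch of the resulting API (the toy dataset, Keras model, and hyperparameters below are illustrative assumptions, not taken from the commit):

    import numpy as np
    import tensorflow as tf
    import deepchem as dc
    from typing import List

    # Toy regression problem, assumed only for illustration.
    X = np.random.rand(100, 1)
    y = 2.0 * X
    dataset = dc.data.NumpyDataset(X, y)
    keras_model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
    model = dc.models.KerasModel(
        keras_model, dc.models.losses.L2Loss(), log_frequency=10)

    losses: List[float] = []  # matches the new Optional[List[float]] hint
    last_loss = model.fit(dataset, nb_epoch=100, all_losses=losses)

    # An averaged loss is appended every `log_frequency` steps; after this
    # commit, fit() returns the most recent of those averages directly.
    assert losses and last_loss == losses[-1]

Calling `fit()` again with the same list keeps appending, as the docstring changes in the diff below note.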
27 changes: 14 additions & 13 deletions deepchem/models/keras_model.py
@@ -273,7 +273,7 @@ def fit(self,
           variables: Optional[List[tf.Variable]] = None,
           loss: Optional[KerasLossFn] = None,
           callbacks: Union[Callable, List[Callable]] = [],
-          all_losses: Optional[list] = None) -> float:
+          all_losses: Optional[List[float]] = None) -> float:
     """Train this model on a dataset.

     Parameters
@@ -303,7 +303,7 @@ def fit(self,
     callbacks: function or list of functions
       one or more functions of the form f(model, step) that will be invoked after
       every step. This can be used to perform validation, logging, etc.
-    all_losses: list, optional (default False)
+    all_losses: Optional[List[float]], optional (default None)
       If specified, all logged losses are appended into this list. Note that
       you can call `fit()` repeatedly with the same list and losses will
       continue to be appended.
@@ -326,7 +326,7 @@ def fit_generator(self,
                     variables: Optional[List[tf.Variable]] = None,
                     loss: Optional[KerasLossFn] = None,
                     callbacks: Union[Callable, List[Callable]] = [],
-                    all_losses: Optional[list] = None) -> float:
+                    all_losses: Optional[List[float]] = None) -> float:
     """Train this model on data from a generator.

     Parameters
@@ -352,7 +352,7 @@ def fit_generator(self,
     callbacks: function or list of functions
       one or more functions of the form f(model, step) that will be invoked after
       every step. This can be used to perform validation, logging, etc.
-    all_losses: list, optional (default False)
+    all_losses: Optional[List[float]], optional (default None)
       If specified, all logged losses are appended into this list. Note that
       you can call `fit()` repeatedly with the same list and losses will
       continue to be appended.
@@ -367,8 +367,8 @@ def fit_generator(self,
     if checkpoint_interval > 0:
       manager = tf.train.CheckpointManager(self._checkpoint, self.model_dir,
                                            max_checkpoints_to_keep)
-    avg_losses = []
     avg_loss = 0.0
+    last_avg_loss = 0.0
     averaged_batches = 0
     train_op = None
     if loss is None:
@@ -414,7 +414,11 @@ def fit_generator(self,
         avg_loss = float(avg_loss) / averaged_batches
         logger.info(
             'Ending global_step %d: Average loss %g' % (current_step, avg_loss))
-        avg_losses.append(avg_loss)
+        if all_losses is not None:
+          all_losses.append(avg_loss)
+        # Capture the last avg_loss for the return value, since we reset
+        # the accumulator to 0 below.
+        last_avg_loss = avg_loss
         avg_loss = 0.0
         averaged_batches = 0

@@ -433,19 +437,16 @@ def fit_generator(self,
       avg_loss = float(avg_loss) / averaged_batches
       logger.info(
           'Ending global_step %d: Average loss %g' % (current_step, avg_loss))
-      avg_losses.append(avg_loss)
+      if all_losses is not None:
+        all_losses.append(avg_loss)
+      last_avg_loss = avg_loss

     if checkpoint_interval > 0:
       manager.save()

     time2 = time.time()
     logger.info("TIMING: model fitting took %0.3f s" % (time2 - time1))
-    if all_losses is not None:
-      all_losses.extend(avg_losses)
-    if len(avg_losses) > 0:
-      return avg_losses[-1]
-    else:
-      return 0.0
+    return last_avg_loss

   def _create_gradient_fn(self,
                           variables: Optional[List[tf.Variable]]) -> Callable:
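Net effect of the `fit_generator()` changes above: rather than buffering averages in a local `avg_losses` list and copying it into `all_losses` at the end, each average is appended to the caller's list as soon as it is logged, and a `last_avg_loss` scalar carries the return value. A self-contained sketch of that bookkeeping (the `run_steps` helper and its arguments are hypothetical, written only to mirror the accumulate/flush/reset pattern of the commit):

    from typing import List, Optional

    def run_steps(step_losses: List[float],
                  log_frequency: int = 2,
                  all_losses: Optional[List[float]] = None) -> float:
      avg_loss = 0.0
      last_avg_loss = 0.0
      averaged_batches = 0
      for step, loss in enumerate(step_losses, start=1):
        avg_loss += loss
        averaged_batches += 1
        if step % log_frequency == 0:
          avg_loss = float(avg_loss) / averaged_batches
          if all_losses is not None:
            all_losses.append(avg_loss)
          # Remember the flushed average before resetting the accumulator.
          last_avg_loss = avg_loss
          avg_loss = 0.0
          averaged_batches = 0
      if averaged_batches > 0:  # flush a leftover partial window
        avg_loss = float(avg_loss) / averaged_batches
        if all_losses is not None:
          all_losses.append(avg_loss)
        last_avg_loss = avg_loss
      return last_avg_loss

For example, with `acc = []`, `run_steps([1.0, 2.0, 3.0], all_losses=acc)` returns 3.0 and leaves `acc == [1.5, 3.0]`, the same last value the old `avg_losses[-1]` return produced, without the extra buffer.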
1 change: 1 addition & 0 deletions deepchem/models/tests/test_kerasmodel.py
@@ -78,6 +78,7 @@ def test_fit_use_all_losses():
   model.fit(dataset, nb_epoch=1000, all_losses=losses)
   # Each epoch is a single step for this model
   assert len(losses) == 100
+  assert np.count_nonzero(np.array(losses)) == 100


 def test_fit_on_batch():
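The added assertion tightens the test: `len(losses) == 100` alone would still pass if zero-valued placeholders were appended, whereas `np.count_nonzero` verifies that every recorded entry is a genuine nonzero averaged loss. A tiny illustration of the check (values made up):

    import numpy as np

    healthy = [0.5, 0.25, 0.125]   # every logged average is nonzero
    assert np.count_nonzero(np.array(healthy)) == len(healthy)

    buggy = [0.5, 0.0, 0.125]      # a stray 0.0 placeholder is caught
    assert np.count_nonzero(np.array(buggy)) != len(buggy)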
