Skip to content

Commit

Permalink
Merge 0a51ace into 9096263
Browse files Browse the repository at this point in the history
  • Loading branch information
rbharath committed Jun 25, 2020
2 parents 9096263 + 0a51ace commit 3f3db04
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 6 deletions.
15 changes: 10 additions & 5 deletions deepchem/models/keras_model.py
Expand Up @@ -104,7 +104,7 @@ def __init__(self,
learning_rate=0.001,
optimizer=None,
tensorboard=False,
tensorboard_log_frequency=100,
log_frequency=100,
**kwargs):
"""Create a new KerasModel.
Expand All @@ -130,8 +130,13 @@ def __init__(self,
ignored.
tensorboard: bool
whether to log progress to TensorBoard during training
tensorboard_log_frequency: int
the frequency at which to log data to TensorBoard, measured in batches
log_frequency: int
The frequency at which to log data. Data is logged using
`logging` by default. If `tensorboard` is set, data is also
logged to TensorBoard. Logging happens at global steps. Roughly,
a global step corresponds to one batch of training. If you'd
like a printout every 10 batch steps, you'd set
`log_frequency=10` for example.
"""
super(KerasModel, self).__init__(
model_instance=model, model_dir=model_dir, **kwargs)
Expand All @@ -146,7 +151,7 @@ def __init__(self,
else:
self.optimizer = optimizer
self.tensorboard = tensorboard
self.tensorboard_log_frequency = tensorboard_log_frequency
self.log_frequency = log_frequency
if self.tensorboard:
self._summary_writer = tf.summary.create_file_writer(self.model_dir)
if output_types is None:
Expand Down Expand Up @@ -348,7 +353,7 @@ def fit_generator(self,

# Report progress and write checkpoints.
averaged_batches += 1
should_log = (current_step % self.tensorboard_log_frequency == 0)
should_log = (current_step % self.log_frequency == 0)
if should_log:
avg_loss = float(avg_loss) / averaged_batches
logger.info(
Expand Down
2 changes: 1 addition & 1 deletion deepchem/models/tests/test_kerasmodel.py
Expand Up @@ -235,7 +235,7 @@ def test_tensorboard(self):
keras_model,
dc.models.losses.CategoricalCrossEntropy(),
tensorboard=True,
tensorboard_log_frequency=1)
log_frequency=1)
model.fit(dataset, nb_epoch=10)
files_in_dir = os.listdir(model.model_dir)
event_file = list(filter(lambda x: x.startswith("events"), files_in_dir))
Expand Down

0 comments on commit 3f3db04

Please sign in to comment.