Skip to content

Commit

Permalink
Update tuner.py
Browse files Browse the repository at this point in the history
Resolves #642
  • Loading branch information
haifeng-jin committed Feb 28, 2022
1 parent be68257 commit d3c18dc
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions keras_tuner/engine/tuner.py
Expand Up @@ -26,6 +26,7 @@

from keras_tuner import config as config_module
from keras_tuner.engine import base_tuner
from keras_tuner.engine import objective as obj_module
from keras_tuner.engine import tuner_utils

MAX_FAIL_STREAK = 5
Expand Down Expand Up @@ -275,10 +276,14 @@ def run_trial(self, trial, *args, **kwargs):
"""
# Not using `ModelCheckpoint` to support MultiObjective.
# It can only track one of the metrics to save the best model.
model_checkpoint = tuner_utils.SaveBestEpoch(
objective=self.oracle.objective,
filepath=self._get_checkpoint_fname(trial.trial_id),
)
filepath = self._get_checkpoint_fname(trial.trial_id)
if isinstance(self.oracle.objective, obj_module.DefaultObjective):
model_checkpoint = keras.callbacks.SaveModelCheckpoint(filepath)
else:
model_checkpoint = tuner_utils.SaveBestEpoch(
objective=self.oracle.objective,
filepath=filepath,
)
original_callbacks = kwargs.pop("callbacks", [])

# Run the training process multiple times.
Expand Down

0 comments on commit d3c18dc

Please sign in to comment.