Skip to content

Commit

Permalink
add param to only save the best checkpoint in ray tune hyperopt exp (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
ANarayan authored and ShreyaR committed Aug 17, 2021
1 parent fb02a2f commit e8981ee
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions ludwig/hyperopt/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ def _has_eval_metric(self, stats):

def get_metric_score(self, train_stats, eval_stats) -> float:
if self._has_metric(train_stats, TEST):
logger.info("Returning metric score from training (test) statistics")
logger.info(
"Returning metric score from training (test) statistics")
return self.get_metric_score_from_train_stats(train_stats, TEST)
elif self._has_eval_metric(eval_stats):
logger.info("Returning metric score from eval statistics. "
Expand All @@ -86,7 +87,8 @@ def get_metric_score(self, train_stats, eval_stats) -> float:
"best validation performance")
return self.get_metric_score_from_eval_stats(eval_stats)
elif self._has_metric(train_stats, VALIDATION):
logger.info("Returning metric score from training (validation) statistics")
logger.info(
"Returning metric score from training (validation) statistics")
return self.get_metric_score_from_train_stats(train_stats, VALIDATION)
elif self._has_metric(train_stats, TRAINING):
logger.info("Returning metric score from training split statistics, "
Expand Down Expand Up @@ -948,6 +950,7 @@ def run_experiment_trial(config, checkpoint_dir=None):
scheduler=self.scheduler,
search_alg=search_alg,
num_samples=self.num_samples,
keep_checkpoints_num=1,
resources_per_trial=resources_per_trial,
time_budget_s=self.time_budget_s,
queue_trials=True,
Expand Down

0 comments on commit e8981ee

Please sign in to comment.