diff --git a/submission_runner.py b/submission_runner.py index 2289d39d3..717ea2dc4 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -17,6 +17,7 @@ import datetime import gc import importlib +import itertools import json import os import struct @@ -133,6 +134,14 @@ flags.DEFINE_boolean('save_checkpoints', True, 'Whether or not to checkpoint the model at every eval.') +flags.DEFINE_integer( + 'hparam_start_index', + None, + 'Start index to slice set of hyperparameters in tuning search space.') +flags.DEFINE_integer( + 'hparam_end_index', + None, + 'End index to slice set of hyperparameters in tuning spearch space.') flags.DEFINE_integer( 'rng_seed', None, @@ -455,6 +464,8 @@ def score_submission_on_workload(workload: spec.Workload, num_tuning_trials: Optional[int] = None, log_dir: Optional[str] = None, save_checkpoints: Optional[bool] = True, + hparam_start_index: Optional[bool] = None, + hparam_end_index: Optional[bool] = None, rng_seed: Optional[int] = None): # Expand paths because '~' may not be recognized data_dir = os.path.expanduser(data_dir) @@ -500,7 +511,9 @@ def score_submission_on_workload(workload: spec.Workload, json.load(search_space_file), num_tuning_trials) all_timings = [] all_metrics = [] - for hi, hyperparameters in enumerate(tuning_search_space): + tuning_search_space_iter = itertools.islice( + enumerate(tuning_search_space), hparam_start_index, hparam_end_index) + for hi, hyperparameters in tuning_search_space_iter: # Generate a new seed from hardware sources of randomness for each trial. if not rng_seed: rng_seed = struct.unpack('I', os.urandom(4))[0] @@ -545,7 +558,7 @@ def score_submission_on_workload(workload: spec.Workload, all_timings.append(timing) all_metrics.append(metrics) score = min(all_timings) - for ti in range(num_tuning_trials): + for ti, _ in tuning_search_space_iter: logging.info(f'Tuning trial {ti + 1}/{num_tuning_trials}') logging.info(f'Hyperparameters: {tuning_search_space[ti]}') logging.info(f'Metrics: {all_metrics[ti]}') @@ -621,6 +634,8 @@ def main(_): num_tuning_trials=FLAGS.num_tuning_trials, log_dir=logging_dir_path, save_checkpoints=FLAGS.save_checkpoints, + hparam_start_index=FLAGS.hparam_start_index, + hparam_end_index=FLAGS.hparam_end_index, rng_seed=FLAGS.rng_seed) logging.info(f'Final {FLAGS.workload} score: {score}')