diff --git a/algorithmic_efficiency/logger_utils.py b/algorithmic_efficiency/logger_utils.py index af2e61581..2b3cf86f6 100644 --- a/algorithmic_efficiency/logger_utils.py +++ b/algorithmic_efficiency/logger_utils.py @@ -275,6 +275,12 @@ def get_meta_data(workload: spec.Workload) -> dict: return meta_data +def save_meta_data(workload: spec.Workload, rng_seed: int, meta_file_name: str): + meta_data = get_meta_data(workload) + meta_data.update({'rng_seed': rng_seed}) + write_json(meta_file_name, meta_data) + + class MetricLogger(object): """Used to log all measurements during training. diff --git a/setup.cfg b/setup.cfg index 6f53cd51b..a7ce5ebb2 100644 --- a/setup.cfg +++ b/setup.cfg @@ -115,6 +115,7 @@ jax_core_deps = # Todo(kasimbeg): verify if this is necessary after we # upgrade jax. chex==0.1.7 + ml_dtypes==0.2.0 # JAX CPU jax_cpu = diff --git a/submission_runner.py b/submission_runner.py index f4ee32ede..2289d39d3 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -133,6 +133,11 @@ flags.DEFINE_boolean('save_checkpoints', True, 'Whether or not to checkpoint the model at every eval.') +flags.DEFINE_integer( + 'rng_seed', + None, + 'Value of rng seed. 
If None, a random seed will ' 'be generated from hardware.') FLAGS = flags.FLAGS USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup() @@ -173,6 +178,7 @@ def train_once( update_params: spec.UpdateParamsFn, data_selection: spec.DataSelectionFn, hyperparameters: Optional[spec.Hyperparameters], + rng_seed: int, rng: spec.RandomState, profiler: Profiler, max_global_steps: int = None, @@ -267,10 +273,9 @@ global_step, preemption_count, checkpoint_dir=log_dir) - meta_data = logger_utils.get_meta_data(workload) meta_file_name = os.path.join(log_dir, f'meta_data_{preemption_count}.json') logging.info(f'Saving meta data to {meta_file_name}.') - logger_utils.write_json(meta_file_name, meta_data) + logger_utils.save_meta_data(workload, rng_seed, meta_file_name) flag_file_name = os.path.join(log_dir, f'flags_{preemption_count}.json') logging.info(f'Saving flags to {flag_file_name}.') logger_utils.write_json(flag_file_name, flags.FLAGS.flag_values_dict()) @@ -449,7 +454,8 @@ def score_submission_on_workload(workload: spec.Workload, tuning_search_space: Optional[str] = None, num_tuning_trials: Optional[int] = None, log_dir: Optional[str] = None, - save_checkpoints: Optional[bool] = True): + save_checkpoints: Optional[bool] = True, + rng_seed: Optional[int] = None): # Expand paths because '~' may not be recognized data_dir = os.path.expanduser(data_dir) if imagenet_v2_data_dir: @@ -496,7 +502,8 @@ def score_submission_on_workload(workload: spec.Workload, all_metrics = [] for hi, hyperparameters in enumerate(tuning_search_space): # Generate a new seed from hardware sources of randomness for each trial. 
- rng_seed = struct.unpack('I', os.urandom(4))[0] + if rng_seed is None: + rng_seed = struct.unpack('I', os.urandom(4))[0] logging.info('Using RNG seed %d', rng_seed) rng = prng.PRNGKey(rng_seed) # Because we initialize the PRNGKey with only a single 32 bit int, in the @@ -528,7 +535,9 @@ data_dir, imagenet_v2_data_dir, init_optimizer_state, update_params, data_selection, - hyperparameters, rng, + hyperparameters, + rng_seed, + rng, profiler, max_global_steps, tuning_dir_name, @@ -545,7 +554,8 @@ logging.info(f'Total number of evals: {num_evals}') logging.info('=' * 20) else: - rng_seed = struct.unpack('q', os.urandom(8))[0] + if rng_seed is None: + rng_seed = struct.unpack('q', os.urandom(8))[0] rng = prng.PRNGKey(rng_seed) # If the submission is responsible for tuning itself, we only need to run it # once and return the total time. @@ -554,7 +564,7 @@ workload, global_batch_size, global_eval_batch_size, data_dir, imagenet_v2_data_dir, init_optimizer_state, update_params, data_selection, - None, rng, profiler, max_global_steps, log_dir, + None, rng_seed, rng, profiler, max_global_steps, log_dir, save_checkpoints=save_checkpoints) return score @@ -610,7 +620,8 @@ def main(_): tuning_search_space=FLAGS.tuning_search_space, num_tuning_trials=FLAGS.num_tuning_trials, log_dir=logging_dir_path, - save_checkpoints=FLAGS.save_checkpoints) + save_checkpoints=FLAGS.save_checkpoints, + rng_seed=FLAGS.rng_seed) logging.info(f'Final {FLAGS.workload} score: {score}') if FLAGS.profile: