diff --git a/mlperf_logging/compliance_checker/training_5.1.0/closed_llama31_8b.yaml b/mlperf_logging/compliance_checker/training_5.1.0/closed_llama31_8b.yaml index b6335dc..2aba722 100644 --- a/mlperf_logging/compliance_checker/training_5.1.0/closed_llama31_8b.yaml +++ b/mlperf_logging/compliance_checker/training_5.1.0/closed_llama31_8b.yaml @@ -25,7 +25,7 @@ - KEY: NAME: opt_learning_rate_decay_steps REQ: EXACTLY_ONE - CHECK: " v['value'] * s['global_batch_size'] == 1.2e6 " + CHECK: " v['value'] == 1200000 " - KEY: NAME: opt_learning_rate_warmup_steps @@ -79,7 +79,7 @@ ATLEAST_ONE_CHECK: "(v['value'] <= 3.3) and v['value'] > 0.0" - KEY: - NAME: MAX_STEPS + NAME: max_steps REQ: EXACTLY_ONE CHECK: " v['value'] == 1200000 " diff --git a/mlperf_logging/mllog/constants.py b/mlperf_logging/mllog/constants.py index 880a814..57972a6 100644 --- a/mlperf_logging/mllog/constants.py +++ b/mlperf_logging/mllog/constants.py @@ -117,6 +117,7 @@ LARS_OPT_WEIGHT_DECAY = "lars_opt_weight_decay" MAX_IMAGE_SIZE = "max_image_size" MAX_SAMPLES = "max_samples" +MAX_STEPS = "max_steps" MAX_SEQUENCE_LENGTH = "max_sequence_length" MIN_IMAGE_SIZE = "min_image_size" MODEL_BN_SPAN = "model_bn_span"