diff --git a/tests/models/test_decoders.py b/tests/models/test_decoders.py index 45628a24..45959169 100644 --- a/tests/models/test_decoders.py +++ b/tests/models/test_decoders.py @@ -69,7 +69,9 @@ USE_MICRO_MODELS = os.environ.get("FMS_TEST_SHAPES_USE_MICRO_MODELS", "1") == "1" USE_DISTRIBUTED = os.environ.get("FMS_TEST_SHAPES_DISTRIBUTED", "0") == "1" TIMING = os.environ.get("TIMING", "") - +CUMULATIVE_TEST_TOKENS_PER_SEQUENCE = int( + os.environ.get("FMS_TEST_SHAPES_CUMULATIVE_TEST_TOKENS_PER_SEQUENCE", "1024") +) ATTN_TYPE = os.environ.get("FMS_TEST_SHAPES_ATTN_TYPE", "sdpa") attention_map = { "sdpa": "sdpa_causal", @@ -608,7 +610,7 @@ def _metric_calculator(r: torch.Tensor, t: torch.Tensor): ) return (cross_entropy, diff) - iters = 1024 // max_new_tokens + iters = int(CUMULATIVE_TEST_TOKENS_PER_SEQUENCE) // max_new_tokens ce_fail_responses_list = [] diff_fail_responses_list = [] total_tokens = 0