diff --git a/tests/integration_tests/test_llm.py b/tests/integration_tests/test_llm.py
index 389922fd35e..9644a9e95a2 100644
--- a/tests/integration_tests/test_llm.py
+++ b/tests/integration_tests/test_llm.py
@@ -29,6 +29,7 @@
 
 LOCAL_BACKEND = {"type": "local"}
 TEST_MODEL_NAME = "hf-internal-testing/tiny-random-GPTJForCausalLM"
+MAX_NEW_TOKENS_TEST_DEFAULT = 5
 
 RAY_BACKEND = {
     "type": "ray",
@@ -46,6 +47,11 @@
 }
 
 
+def get_num_non_empty_tokens(iterable):
+    """Returns the number of non-empty tokens."""
+    return len(list(filter(bool, iterable)))
+
+
 @pytest.fixture(scope="module")
 def local_backend():
     return LOCAL_BACKEND
@@ -79,7 +85,7 @@ def get_generation_config():
         "top_p": 0.75,
         "top_k": 40,
         "num_beams": 4,
-        "max_new_tokens": 5,
+        "max_new_tokens": MAX_NEW_TOKENS_TEST_DEFAULT,
     }
 
 
@@ -136,6 +142,25 @@ def test_llm_text_to_text(tmpdir, backend, ray_cluster_4cpu):
     assert preds["Answer_probability"]
     assert preds["Answer_response"]
 
+    # Check that in-line generation parameters are used. The original prediction uses max_new_tokens = 5.
+    assert get_num_non_empty_tokens(preds["Answer_predictions"][0]) <= MAX_NEW_TOKENS_TEST_DEFAULT
+    original_max_new_tokens = model.model.generation.max_new_tokens
+
+    # This prediction uses min_new_tokens = 2 and max_new_tokens = 3.
+    preds, _ = model.predict(
+        dataset=dataset_filename,
+        output_directory=str(tmpdir),
+        split="test",
+        generation_config={"min_new_tokens": 2, "max_new_tokens": 3},
+    )
+    preds = convert_preds(preds)
+    print(preds["Answer_predictions"][0])
+    num_non_empty_tokens = get_num_non_empty_tokens(preds["Answer_predictions"][0])
+    assert 2 <= num_non_empty_tokens <= 3
+
+    # Check that the state of the model is unchanged.
+    assert model.model.generation.max_new_tokens == original_max_new_tokens
+
 
 @pytest.mark.llm
 @pytest.mark.parametrize(