Skip to content

Commit

Permalink
Add test that verifies that the generation config passed in at model.…
Browse files Browse the repository at this point in the history
…predict() is used correctly. (#3523)
  • Loading branch information
justinxzhao committed Aug 11, 2023
1 parent 7ecdfa5 commit 095a1ad
Showing 1 changed file with 26 additions and 1 deletion.
27 changes: 26 additions & 1 deletion tests/integration_tests/test_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

LOCAL_BACKEND = {"type": "local"}
TEST_MODEL_NAME = "hf-internal-testing/tiny-random-GPTJForCausalLM"
MAX_NEW_TOKENS_TEST_DEFAULT = 5

RAY_BACKEND = {
"type": "ray",
Expand All @@ -46,6 +47,11 @@
}


def get_num_non_empty_tokens(iterable):
    """Count the truthy (non-empty) tokens in ``iterable``."""
    return sum(1 for token in iterable if token)


@pytest.fixture(scope="module")
def local_backend():
return LOCAL_BACKEND
Expand Down Expand Up @@ -79,7 +85,7 @@ def get_generation_config():
"top_p": 0.75,
"top_k": 40,
"num_beams": 4,
"max_new_tokens": 5,
"max_new_tokens": MAX_NEW_TOKENS_TEST_DEFAULT,
}


Expand Down Expand Up @@ -136,6 +142,25 @@ def test_llm_text_to_text(tmpdir, backend, ray_cluster_4cpu):
assert preds["Answer_probability"]
assert preds["Answer_response"]

# Check that in-line generation parameters are used. Original prediction uses max_new_tokens = 5.
assert get_num_non_empty_tokens(preds["Answer_predictions"][0]) <= MAX_NEW_TOKENS_TEST_DEFAULT
original_max_new_tokens = model.model.generation.max_new_tokens

# This prediction uses min_new_tokens = 2 and max_new_tokens = 3.
preds, _ = model.predict(
dataset=dataset_filename,
output_directory=str(tmpdir),
split="test",
generation_config={"min_new_tokens": 2, "max_new_tokens": 3},
)
preds = convert_preds(preds)
print(preds["Answer_predictions"][0])
num_non_empty_tokens = get_num_non_empty_tokens(preds["Answer_predictions"][0])
assert 2 <= num_non_empty_tokens <= 3

# Check that the state of the model is unchanged.
assert model.model.generation.max_new_tokens == original_max_new_tokens


@pytest.mark.llm
@pytest.mark.parametrize(
Expand Down

0 comments on commit 095a1ad

Please sign in to comment.