update unit test
faaany committed May 17, 2023
1 parent dd993b3 commit a27bdde
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions test/prompt/test_prompt_node.py
@@ -251,19 +251,21 @@ def test_generation_kwargs_from_prompt_node_call():


 @pytest.mark.integration
-@pytest.mark.parametrize("prompt_model", ["google/flan-t5-base", "gpt-3.5-turbo"], indirect=True)
+@pytest.mark.parametrize("prompt_model", ["hf", "openai", "azure"], indirect=True)
 def test_generation_kwargs_from_prompt_node_run(prompt_model):
     skip_test_for_invalid_key(prompt_model)
     the_question = "What does 42 mean?"
     # test that generation_kwargs are passed to the underlying invocation layer
-    node = PromptNode(model_name_or_path=prompt_model)
-    with patch.object(node.prompt_model.model_invocation_layer.pipe, "run_single", MagicMock()) as mock_call:
+    node = PromptNode(prompt_model)
+    with patch.object(node.prompt_model.model_invocation_layer, "invoke", MagicMock()) as mock_call:
         node.run(query=the_question, prompt_template="{query}", generation_kwargs={"do_sample": True, "test": True})

     mock_call.assert_called_with(
-        the_question,
-        {},
-        {"do_sample": True, "test": True, "num_return_sequences": 1, "num_beams": 1, "max_length": 100},
-        {},
+        prompt=the_question,
+        stop_words=None,
+        top_k=1,
+        query=the_question,
+        generation_kwargs={"do_sample": True, "test": True},
     )


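The core of the change is that the test now patches the invocation layer's invoke method directly (instead of pipe.run_single) and asserts on keyword arguments. Below is a minimal, self-contained sketch of that mocking pattern using plain unittest.mock; FakeInvocationLayer and FakeNode are made-up stand-ins for illustration, not Haystack classes.

# Sketch of the mock-and-assert pattern used by the updated test
# (illustrative names only; not Haystack code).
from unittest.mock import MagicMock, patch


class FakeInvocationLayer:
    """Stand-in for a model invocation layer; ``invoke`` is the method being patched."""

    def invoke(self, *args, **kwargs):
        return ["real model output"]


class FakeNode:
    """Stand-in for a node that forwards generation kwargs to its invocation layer."""

    def __init__(self):
        self.model_invocation_layer = FakeInvocationLayer()

    def run(self, query, generation_kwargs=None):
        # Forward everything as keyword arguments, mirroring the call shape
        # that the updated assertion expects.
        return self.model_invocation_layer.invoke(query=query, generation_kwargs=generation_kwargs)


def test_generation_kwargs_are_forwarded():
    node = FakeNode()
    # patch.object swaps the method for a MagicMock only inside the context
    # manager, so the real invoke() is never called.
    with patch.object(node.model_invocation_layer, "invoke", MagicMock()) as mock_call:
        node.run("What does 42 mean?", generation_kwargs={"do_sample": True})

    # assert_called_with compares the exact positional and keyword arguments of
    # the most recent call, which is why the expected call in the test had to be
    # rewritten in keyword form when the patched method changed.
    mock_call.assert_called_with(
        query="What does 42 mean?",
        generation_kwargs={"do_sample": True},
    )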

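The parametrize change above also switches from concrete model names to short fixture keys with indirect=True. As a reminder of how indirect parametrization behaves, here is a generic pytest sketch (not the real Haystack prompt_model fixture; the providers mapping is a placeholder): each parametrized value reaches the same-named fixture via request.param, and the test receives whatever the fixture returns.

# Generic pytest sketch of indirect parametrization (illustrative only).
import pytest


@pytest.fixture
def prompt_model(request):
    # request.param receives each value from the parametrize list below;
    # the fixture translates the short key into a configured object.
    providers = {
        "hf": "huggingface-backed model (placeholder)",
        "openai": "openai-backed model (placeholder)",
        "azure": "azure-openai-backed model (placeholder)",
    }
    return providers[request.param]


@pytest.mark.parametrize("prompt_model", ["hf", "openai", "azure"], indirect=True)
def test_receives_constructed_object(prompt_model):
    # The test body sees the fixture's return value, not the raw string key.
    assert "model" in prompt_model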