re-enable some tests and change models to be smaller
Signed-off-by: Serena Ruan <serena.rxy@gmail.com>
serena-ruan committed Nov 9, 2023
1 parent 4f70869 commit 7e4dd09
28 changes: 10 additions & 18 deletions tests/transformers/test_transformers_model_export.py
@@ -253,7 +253,7 @@ def fill_mask_pipeline():
@flaky()
def text2text_generation_pipeline():
    task = "text2text-generation"
-    architecture = "mrm8488/t5-base-finetuned-common_gen"
+    architecture = "mrm8488/t5-small-finetuned-common_gen"
    model = transformers.T5ForConditionalGeneration.from_pretrained(architecture)
    tokenizer = transformers.T5TokenizerFast.from_pretrained(architecture)

@@ -293,7 +293,7 @@ def translation_pipeline():
@flaky()
def summarizer_pipeline():
    task = "summarization"
-    architecture = "philschmid/distilbart-cnn-12-6-samsum"
+    architecture = "sshleifer/distilbart-cnn-6-6"
    model = transformers.BartForConditionalGeneration.from_pretrained(architecture)
    tokenizer = transformers.AutoTokenizer.from_pretrained(architecture)
    return transformers.pipeline(
@@ -363,7 +363,7 @@ def ner_pipeline_aggregation():
@pytest.fixture
@flaky()
def conversational_pipeline():
-    return transformers.pipeline(model="microsoft/DialoGPT-medium")
+    return transformers.pipeline(model="AVeryRealHuman/DialoGPT-small-TonyStark")


@pytest.fixture
@@ -1341,18 +1341,17 @@ def test_qa_pipeline_pyfunc_load_and_infer(small_qa_pipeline, model_path, infere
    assert all(isinstance(element, str) for element in inference)


-@pytest.mark.skipif(RUNNING_IN_GITHUB_ACTIONS, reason=GITHUB_ACTIONS_SKIP_REASON)
@pytest.mark.parametrize(
    ("data", "result"),
    [
-        ("muppet keyboard type", ["A muppet is typing on a keyboard."]),
+        ("muppet keyboard type", ["A man is typing a muppet on a keyboard."]),
        (
            ["pencil draw paper", "pie apple eat"],
            # NB: The result of this test case, without inference config overrides is:
            # ["A man drawing on paper with pencil", "A man eating a pie with applies"]
            # The inference config override forces additional insertion of more grammatically
            # correct responses to validate that the inference config is being applied.
-            ["A man is drawing on paper with a pencil.", "A man is eating a pie with apples."],
+            ["A man draws a pencil on a paper.", "A man eats a pie of apples."],
        ),
    ],
)
@@ -1510,7 +1509,6 @@ def test_text2text_generation_pipeline_with_params_with_errors(
        pyfunc_loaded.predict(data, {"top_k": "2"})


-@pytest.mark.skipif(RUNNING_IN_GITHUB_ACTIONS, reason=GITHUB_ACTIONS_SKIP_REASON)
def test_text2text_generation_pipeline_with_inferred_schema(text2text_generation_pipeline):
    with mlflow.start_run():
        model_info = mlflow.transformers.log_model(
@@ -1519,7 +1517,7 @@ def test_text2text_generation_pipeline_with_inferred_schema(text2text_generation
    pyfunc_loaded = mlflow.pyfunc.load_model(model_info.model_uri)

    assert pyfunc_loaded.predict("muppet board nails hammer") == [
-        "A muppet is hammering nails on a board."
+        "A hammer with a muppet and nails on a board."
    ]


@@ -1530,7 +1528,6 @@ def test_text2text_generation_pipeline_with_inferred_schema(text2text_generation
        ([{"answer": ["42"], "context": "life"}, {"unmatched": "keys", "cause": "failure"}]),
    ],
)
-@pytest.mark.skipif(RUNNING_IN_GITHUB_ACTIONS, reason=GITHUB_ACTIONS_SKIP_REASON)
def test_invalid_input_to_text2text_pipeline(text2text_generation_pipeline, invalid_data):
    # Adding this validation test due to the fact that we're constructing the input to the
    # Pipeline. The Pipeline requires a format of a pseudo-dict-like string. An example of
@@ -1549,7 +1546,6 @@ def test_invalid_input_to_text2text_pipeline(text2text_generation_pipeline, inva
@pytest.mark.parametrize(
    "data", ["Generative models are", (["Generative models are", "Computers are"])]
)
-@pytest.mark.skipif(RUNNING_IN_GITHUB_ACTIONS, reason=GITHUB_ACTIONS_SKIP_REASON)
def test_text_generation_pipeline(text_generation_pipeline, model_path, data):
    signature = infer_signature(
        data, mlflow.transformers.generate_signature_output(text_generation_pipeline, data)
@@ -1598,7 +1594,6 @@ def test_text_generation_pipeline(text_generation_pipeline, model_path, data):
        (["tell me a story", {"of": "a properly configured pipeline input"}]),
    ],
)
-@pytest.mark.skipif(RUNNING_IN_GITHUB_ACTIONS, reason=GITHUB_ACTIONS_SKIP_REASON)
def test_invalid_input_to_text_generation_pipeline(text_generation_pipeline, invalid_data):
    if isinstance(invalid_data, list):
        match = "If supplying a list, all values must be of string type"
@@ -1798,7 +1793,6 @@ def test_translation_pipeline(translation_pipeline, model_path, data, result):
        ["Baking cookies is quite easy", "Writing unittests is good for"],
    ],
)
-@pytest.mark.skipif(RUNNING_IN_GITHUB_ACTIONS, reason=GITHUB_ACTIONS_SKIP_REASON)
def test_summarization_pipeline(summarizer_pipeline, model_path, data):
    model_config = {
        "top_k": 2,
@@ -1923,7 +1917,6 @@ def test_ner_pipeline(pipeline_name, model_path, data, result, request):
    assert pd_inference == result


-@pytest.mark.skipif(RUNNING_IN_GITHUB_ACTIONS, reason=GITHUB_ACTIONS_SKIP_REASON)
def test_conversational_pipeline(conversational_pipeline, model_path):
    signature = infer_signature(
        "Hi there!",
@@ -1935,21 +1928,21 @@ def test_conversational_pipeline(conversational_pipeline, model_path):

    first_response = loaded_pyfunc.predict("What is the best way to get to Antarctica?")

-    assert first_response == "I think you can get there by boat."
+    assert first_response == "The best way would be to go to space."

    second_response = loaded_pyfunc.predict("What kind of boat should I use?")

-    assert second_response == "A boat that can go to Antarctica."
+    assert second_response == "The best way to get to space would be to reach out and touch it."

    # Test that a new loaded instance has no context.
    loaded_again_pyfunc = mlflow.pyfunc.load_model(model_path)
    third_response = loaded_again_pyfunc.predict("What kind of boat should I use?")

-    assert third_response == "A boat that can't sink."
+    assert third_response == "The one with the guns."

    fourth_response = loaded_again_pyfunc.predict("Can I use it to go to the moon?")

-    assert fourth_response == "Only if you have a boat that can't sink."
+    assert fourth_response == "Sure."


@pytest.mark.parametrize(
@@ -2902,7 +2895,6 @@ def test_instructional_pipeline_with_prompt_in_output(model_path):
        ),
    ],
)
-@pytest.mark.skipif(RUNNING_IN_GITHUB_ACTIONS, reason=GITHUB_ACTIONS_SKIP_REASON)
@pytest.mark.skipcacheclean
def test_signature_inference(pipeline_name, data, result, request):
    pipeline = request.getfixturevalue(pipeline_name)
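Note on the re-enabled tests: each @pytest.mark.skipif(RUNNING_IN_GITHUB_ACTIONS, reason=GITHUB_ACTIONS_SKIP_REASON) line deleted above turns the decorated test back on in CI. The definition of that gate is not part of this diff; the following is a minimal sketch of the usual pattern, assuming the constant is derived from the GITHUB_ACTIONS environment variable (which GitHub Actions sets to "true") and that the reason string is defined alongside it.

import os

import pytest

# Assumed reconstruction -- the real constants are defined elsewhere in
# tests/transformers/test_transformers_model_export.py, not in this diff.
RUNNING_IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true"
GITHUB_ACTIONS_SKIP_REASON = "Test is too resource-heavy for CI runners"


@pytest.mark.skipif(RUNNING_IN_GITHUB_ACTIONS, reason=GITHUB_ACTIONS_SKIP_REASON)
def test_gated_in_ci():
    ...  # skipped on GitHub Actions while the marker is present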

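The model swaps (t5-base to t5-small, distilbart-cnn-12-6-samsum to distilbart-cnn-6-6, DialoGPT-medium to a small DialoGPT variant) are why the hard-coded expected strings change in the same commit: a smaller checkpoint generates different text, so the assertions must be re-pinned. A quick local check along these lines could be used to capture the new outputs; the model name is taken from the diff, the rest is illustrative.

import transformers

# Build the pipeline the same way the updated fixture does.
pipe = transformers.pipeline(
    task="text2text-generation",
    model="mrm8488/t5-small-finetuned-common_gen",
)

# Inspect what the smaller checkpoint actually generates, then update
# the expected strings in the test assertions to match.
print(pipe("muppet keyboard type"))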

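The NB comment retained in the text2text parametrization refers to inference config overrides: generation parameters stored with the logged model that change its output at predict time, the same mechanism as the model_config dict visible in the summarization test. A hedged sketch of the idea, assuming the model_config argument of mlflow.transformers.log_model; the parameter values here are invented for illustration, not the ones used by the tests.

import mlflow
import transformers

pipe = transformers.pipeline(
    task="text2text-generation",
    model="mrm8488/t5-small-finetuned-common_gen",
)

with mlflow.start_run():
    model_info = mlflow.transformers.log_model(
        transformers_model=pipe,
        artifact_path="text2text",
        # Saved with the model and applied whenever it is loaded for inference.
        model_config={"do_sample": False, "num_beams": 5},
    )

loaded = mlflow.pyfunc.load_model(model_info.model_uri)
print(loaded.predict("muppet keyboard type"))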