__init__.py and test_transformers_model_export (#10538)
Signed-off-by: swathi <konakanchi.swathi@gmail.com>
Signed-off-by: Hankyeol Kyung <kghnkl0103@gmail.com>
Signed-off-by: Prithvi Kannan <prithvi.kannan@databricks.com>
Signed-off-by: Jerry Liang <jerry.liang@databricks.com>
Signed-off-by: Jerry Liang <66143562+jerrylian-db@users.noreply.github.com>
Signed-off-by: serena-ruan <serena.ruan@ip-10-110-25-32.us-west-2.compute.internal>
Signed-off-by: Serena Ruan <serena.rxy@gmail.com>
Signed-off-by: Gabriel Fu <hfu.gabriel@gmail.com>
Signed-off-by: Madhu <madhukesav02@gmail.com>
Signed-off-by: Konakanchi Swathi <98085410+KonakanchiSwathi@users.noreply.github.com>
Co-authored-by: Marcus Kyung <kghnkl0103@gmail.com>
Co-authored-by: Prithvi Kannan <46332835+prithvikannan@users.noreply.github.com>
Co-authored-by: Jerry Liang <66143562+jerrylian-db@users.noreply.github.com>
Co-authored-by: WeichenXu <weichen.xu@databricks.com>
Co-authored-by: Siddharth Murching <sid.murching@databricks.com>
Co-authored-by: Serena Ruan <82044803+serena-ruan@users.noreply.github.com>
Co-authored-by: serena-ruan <serena.ruan@ip-10-110-25-32.us-west-2.compute.internal>
Co-authored-by: Gabriel Fu <hfu.gabriel@gmail.com>
Co-authored-by: madhumaddi <madhumaddi@microsoft.com>
Co-authored-by: Madhu <madhukesav02@gmail.com>
Co-authored-by: Ben Wilson <39283302+BenWilson2@users.noreply.github.com>
12 people committed Dec 21, 2023
1 parent cf1cf51 commit b929a3e
Showing 2 changed files with 239 additions and 20 deletions.
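
In effect, this commit removes ImageClassificationPipeline from the pyfunc exclusion list (and drops the image_processor guard), so image-classification pipelines now receive the pyfunc flavor when logged. A minimal usage sketch distilled from the new tests below; the model name, image URL, and artifact path come from the diff, while the surrounding scaffolding is illustrative rather than part of the commit:

    import mlflow
    import transformers

    # An image-classification pipeline, as used by the tests (small_vision_model fixture)
    pipeline = transformers.pipeline(
        task="image-classification", model="google/mobilenet_v2_1.0_224"
    )

    image_url = "https://raw.githubusercontent.com/mlflow/mlflow/master/tests/datasets/cat.png"

    with mlflow.start_run():
        model_info = mlflow.transformers.log_model(
            transformers_model=pipeline,
            artifact_path="image_classification_model",
        )

    # With this change the pyfunc flavor is attached, so the model can be loaded
    # and queried with an image URL, a local file path, or a base64-encoded image.
    loaded = mlflow.pyfunc.load_model(model_info.model_uri)
    predictions = loaded.predict(image_url)  # DataFrame with "label" and "score" columns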
86 changes: 77 additions & 9 deletions mlflow/transformers/__init__.py
@@ -1344,7 +1344,6 @@ def _should_add_pyfunc_to_model(pipeline) -> bool:
         "DocumentQuestionAnsweringPipeline",
         "ImageToTextPipeline",
         "VisualQuestionAnsweringPipeline",
-        "ImageClassificationPipeline",
         "ImageSegmentationPipeline",
         "DepthEstimationPipeline",
         "ObjectDetectionPipeline",
@@ -1354,11 +1353,6 @@ def _should_add_pyfunc_to_model(pipeline) -> bool:
         "ZeroShotAudioClassificationPipeline",
     ]
 
-    impermissible_attrs = {"image_processor"}
-
-    for attr in impermissible_attrs:
-        if getattr(pipeline, attr, None) is not None:
-            return False
     for model_type in exclusion_model_types:
         if hasattr(transformers, model_type):
             if isinstance(pipeline.model, getattr(transformers, model_type)):
@@ -1427,7 +1421,13 @@ def _get_default_pipeline_signature(pipeline, example=None, model_config=None) -> ModelSignature:
         return ModelSignature(
             inputs=Schema([ColSpec("string")]), outputs=Schema([ColSpec("string")])
         )
-    elif isinstance(pipeline, transformers.TextClassificationPipeline):
+    elif isinstance(
+        pipeline,
+        (
+            transformers.TextClassificationPipeline,
+            transformers.ImageClassificationPipeline,
+        ),
+    ):
         return ModelSignature(
             inputs=Schema([ColSpec("string")]),
             outputs=Schema([ColSpec("string", name="label"), ColSpec("double", name="score")]),
@@ -1815,6 +1815,9 @@ def _predict(self, data):
             output_key = "token_str"
         elif isinstance(self.pipeline, transformers.TextClassificationPipeline):
             output_key = "label"
+        elif isinstance(self.pipeline, transformers.ImageClassificationPipeline):
+            data = self._convert_image_input(data)
+            output_key = "label"
         elif isinstance(self.pipeline, transformers.ZeroShotClassificationPipeline):
             output_key = "labels"
             data = self._parse_json_encoded_list(data, "candidate_labels")
@@ -1893,7 +1896,11 @@ def _predict(self, data):
             output = json.dumps(raw_output)
         elif isinstance(
             self.pipeline,
-            (transformers.AudioClassificationPipeline, transformers.TextClassificationPipeline),
+            (
+                transformers.AudioClassificationPipeline,
+                transformers.TextClassificationPipeline,
+                transformers.ImageClassificationPipeline,
+            ),
         ):
             return pd.DataFrame(raw_output)
         else:
@@ -2580,6 +2587,65 @@ def _convert_cast_lists_from_np_back_to_list(data):
             parsed_data.append(entry)
         return parsed_data
 
+    @staticmethod
+    def is_base64_image(image):
+        """Check whether the input image is base64 encoded."""
+
+        try:
+            return base64.b64encode(base64.b64decode(image)).decode("utf-8") == image
+        except binascii.Error:
+            return False
+
+    def _convert_image_input(self, input_data):
+        """
+        Conversion utility for decoding the base64 encoded bytes data of a raw image file when
+        parsed through model serving, if applicable. Direct usage of the pyfunc implementation
+        outside of model serving will treat this utility as a noop.
+
+        For reference, the expected encoding for input to Model Serving will be:
+
+        import requests
+        import base64
+
+        response = requests.get("https://www.my.images/a/sound/file.jpg")
+        encoded_image = base64.b64encode(response.content).decode("utf-8")
+
+        inference_data = json.dumps({"inputs": [encoded_image]})
+
+        or
+
+        inference_df = pd.DataFrame(
+            pd.Series([encoded_image], name="image_file")
+        )
+        split_dict = {"dataframe_split": inference_df.to_dict(orient="split")}
+        split_json = json.dumps(split_dict)
+
+        or
+
+        records_dict = {"dataframe_records": inference_df.to_dict(orient="records")}
+        records_json = json.dumps(records_dict)
+
+        This utility will convert this JSON encoded, base64 encoded text back into bytes for
+        input into the Image pipelines for inference.
+        """
+
+        def process_input_element(input_element):
+            input_value = next(iter(input_element.values()))
+            if isinstance(input_value, str) and not self.is_base64_image(input_value):
+                self._validate_str_input_uri_or_file(input_value)
+            return input_value
+
+        if isinstance(input_data, list) and all(
+            isinstance(element, dict) for element in input_data
+        ):
+            # Use a list comprehension for readability and the elimination
+            # of empty collection declarations
+            return [process_input_element(element) for element in input_data]
+        elif isinstance(input_data, str) and not self.is_base64_image(input_data):
+            self._validate_str_input_uri_or_file(input_data)
+
+        return input_data
+
     def _convert_audio_input(self, data):
         """
         Conversion utility for decoding the base64 encoded bytes data of a raw soundfile when
@@ -2672,7 +2738,8 @@ def decode_audio(encoded):
     @staticmethod
     def _validate_str_input_uri_or_file(input_str):
         """
-        Validation of blob references to audio files, if a string is input to the ``predict``
+        Validation of blob references to either audio or image files,
+        if a string is input to the ``predict``
         method, perform validation of the string contents by checking for a valid uri or
         filesystem reference instead of surfacing the cryptic stack trace that is otherwise raised
         for an invalid uri input.
@@ -2694,6 +2761,7 @@ def is_uri(s):
             data_str = f"Received (truncated): {input_str[:20]}..."
             raise MlflowException(
                 "An invalid string input was provided. String inputs to "
-                f"audio files must be either a file location or a uri. {data_str}",
+                "audio or image files must be either a file location or a uri."
+                f" {data_str}",
                 error_code=BAD_REQUEST,
             )
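Aside: the is_base64_image helper added above relies on a round-trip property: re-encoding the decoded payload reproduces the original string only for valid base64 input, while anything else either raises binascii.Error or fails the equality check. A standalone sketch of the same idea (not part of the diff):

    import base64
    import binascii

    def looks_like_base64(s: str) -> bool:
        # Valid base64 survives a decode/encode round trip unchanged;
        # other strings raise binascii.Error or fail the comparison.
        try:
            return base64.b64encode(base64.b64decode(s)).decode("utf-8") == s
        except binascii.Error:
            return False

    encoded = base64.b64encode(b"raw image bytes").decode("utf-8")
    assert looks_like_base64(encoded)
    assert not looks_like_base64("https://example.com/cat.png")  # treated as a uri instead
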
173 changes: 162 additions & 11 deletions tests/transformers/test_transformers_model_export.py
@@ -50,6 +50,7 @@
     _record_pipeline_components,
     _should_add_pyfunc_to_model,
     _TransformersModel,
+    _TransformersWrapper,
     _validate_transformers_task_type,
     _write_card_data,
     get_default_conda_env,
@@ -81,7 +82,8 @@
 # runners#supported-runners-and-hardware-resources for instance specs.
 RUNNING_IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true"
 GITHUB_ACTIONS_SKIP_REASON = "Test consumes too much memory"
-
+image_url = "https://raw.githubusercontent.com/mlflow/mlflow/master/tests/datasets/cat.png"
+image_file_path = pathlib.Path(pathlib.Path(__file__).parent.parent, "datasets", "cat.png")
 # Test that can only be run locally:
 # - Summarization pipeline tests
 # - TextClassifier pipeline tests
@@ -480,7 +482,7 @@ def test_instance_extraction(small_qa_pipeline):
         ("small_qa_pipeline", True),
         ("small_seq2seq_pipeline", True),
         ("small_multi_modal_pipeline", False),
-        ("small_vision_model", False),
+        ("small_vision_model", True),
     ],
 )
 def test_pipeline_eligibility_for_pyfunc_registration(model, result, request):
@@ -600,8 +602,7 @@ def test_model_card_acquisition_vision_model(small_vision_model):
 def test_vision_model_save_pipeline_with_defaults(small_vision_model, model_path):
     mlflow.transformers.save_model(transformers_model=small_vision_model, path=model_path)
     # validate inferred pip requirements
-    with model_path.joinpath("requirements.txt").open() as file:
-        requirements = file.read()
+    requirements = model_path.joinpath("requirements.txt").read_text()
     reqs = {req.split("==")[0] for req in requirements.split("\n")}
     expected_requirements = {"torch", "torchvision", "transformers"}
     assert reqs.intersection(expected_requirements) == expected_requirements
@@ -626,6 +627,29 @@ def test_vision_model_save_pipeline_with_defaults(small_vision_model, model_path):
     assert flavor_config["source_model_name"] == "google/mobilenet_v2_1.0_224"
 
 
+def test_vision_model_save_model_for_task_and_card_inference(small_vision_model, model_path):
+    mlflow.transformers.save_model(transformers_model=small_vision_model, path=model_path)
+    # validate inferred pip requirements
+    requirements = model_path.joinpath("requirements.txt").read_text()
+    reqs = {req.split("==")[0] for req in requirements.split("\n")}
+    expected_requirements = {"torch", "torchvision", "transformers"}
+    assert reqs.intersection(expected_requirements) == expected_requirements
+    # validate inferred model card data
+    card_data = yaml.safe_load(model_path.joinpath("model_card_data.yaml").read_bytes())
+    assert card_data["tags"] == ["vision", "image-classification"]
+    # Validate inferred model card text
+    card_text = model_path.joinpath("model_card.md").read_text(encoding="utf-8")
+    assert len(card_text) > 0
+
+    # Validate the MLModel file
+    mlmodel = yaml.safe_load(model_path.joinpath("MLmodel").read_bytes())
+    flavor_config = mlmodel["flavors"]["transformers"]
+    assert flavor_config["instance_type"] == "ImageClassificationPipeline"
+    assert flavor_config["pipeline_model_type"] == "MobileNetV2ForImageClassification"
+    assert flavor_config["task"] == "image-classification"
+    assert flavor_config["source_model_name"] == "google/mobilenet_v2_1.0_224"
+
+
 def test_qa_model_save_model_for_task_and_card_inference(small_seq2seq_pipeline, model_path):
     mlflow.transformers.save_model(
         transformers_model={
@@ -972,11 +996,6 @@ def test_transformers_log_model_with_no_registered_model_name(small_vision_model
             conda_env=str(conda_env),
         )
         mlflow.tracking._model_registry.fluent._register_model.assert_not_called()
-        model_uri = f"runs:/{mlflow.active_run().info.run_id}/{artifact_path}"
-        model_path = pathlib.Path(_download_artifact_from_uri(artifact_uri=model_uri))
-        model_config = Model.load(str(model_path.joinpath("MLmodel")))
-        # Vision models can't be loaded as pyfunc currently.
-        assert pyfunc.FLAVOR_NAME not in model_config.flavors
 
 
 def test_transformers_save_persists_requirements_in_mlflow_directory(
@@ -1343,6 +1362,40 @@ def test_qa_pipeline_pyfunc_load_and_infer(small_qa_pipeline, model_path, inference_payload):
     assert all(isinstance(element, str) for element in inference)
 
 
+@pytest.mark.parametrize(
+    "inference_payload",
+    [
+        image_url,
+        str(image_file_path),
+        pytest.param(
+            "base64",
+            marks=pytest.mark.skipif(
+                Version(transformers.__version__) < Version("4.33"),
+                reason="base64 feature not present",
+            ),
+        ),
+    ],
+)
+def test_vision_pipeline_pyfunc_load_and_infer(small_vision_model, model_path, inference_payload):
+    if inference_payload == "base64":
+        inference_payload = base64.b64encode(image_file_path.read_bytes()).decode("utf-8")
+    signature = infer_signature(
+        inference_payload,
+        mlflow.transformers.generate_signature_output(small_vision_model, inference_payload),
+    )
+    mlflow.transformers.save_model(
+        transformers_model=small_vision_model,
+        path=model_path,
+        signature=signature,
+    )
+    pyfunc_loaded = mlflow.pyfunc.load_model(model_path)
+    predictions = pyfunc_loaded.predict(inference_payload)
+
+    transformers_loaded_model = mlflow.transformers.load_model(model_path)
+    expected_predictions = transformers_loaded_model.predict(inference_payload)
+    assert list(predictions.to_dict("records")[0].values()) == expected_predictions
+
+
 @pytest.mark.parametrize(
     ("data", "result"),
     [
@@ -2099,6 +2152,64 @@ def test_qa_pipeline_pyfunc_predict(small_qa_pipeline):
     assert values.to_dict(orient="records") == [{0: "Run"}]
 
 
+@pytest.mark.parametrize(
+    ("input_image", "result"),
+    [
+        (str(image_file_path), False),
+        (image_url, False),
+        ("base64", True),
+        ("random string", False),
+    ],
+)
+def test_vision_is_base64_image(input_image, result):
+    if input_image == "base64":
+        input_image = base64.b64encode(image_file_path.read_bytes()).decode("utf-8")
+    assert _TransformersWrapper.is_base64_image(input_image) == result
+
+
+@pytest.mark.parametrize(
+    "inference_payload",
+    [
+        [str(image_file_path)],
+        [image_url],
+        pytest.param(
+            "base64",
+            marks=pytest.mark.skipif(
+                Version(transformers.__version__) < Version("4.33"),
+                reason="base64 feature not present",
+            ),
+        ),
+    ],
+)
+def test_vision_pipeline_pyfunc_predict(small_vision_model, inference_payload):
+    if inference_payload == "base64":
+        inference_payload = [
+            base64.b64encode(image_file_path.read_bytes()).decode("utf-8"),
+        ]
+    artifact_path = "image_classification_model"
+
+    # Log the image classification model
+    with mlflow.start_run():
+        model_info = mlflow.transformers.log_model(
+            transformers_model=small_vision_model,
+            artifact_path=artifact_path,
+        )
+    pyfunc_inference_payload = json.dumps({"inputs": inference_payload})
+    response = pyfunc_serve_and_score_model(
+        model_info.model_uri,
+        data=pyfunc_inference_payload,
+        content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
+        extra_args=["--env-manager", "local"],
+    )
+
+    predictions = PredictionsResponse.from_json(response.content.decode("utf-8")).get_predictions()
+
+    transformers_loaded_model = mlflow.transformers.load_model(model_info.model_uri)
+    expected_predictions = transformers_loaded_model.predict(inference_payload)
+
+    assert [list(pred.values()) for pred in predictions.to_dict("records")] == expected_predictions
+
+
 def test_classifier_pipeline_pyfunc_predict(text_classification_pipeline):
     artifact_path = "text_classifier_model"
     data = [
@@ -3529,6 +3640,48 @@ def test_save_model_card_with_non_utf_characters(tmp_path, model_name):
     assert data == card_data.data.to_dict()
 
 
+def test_vision_pipeline_pyfunc_predict_with_kwargs(small_vision_model):
+    artifact_path = "image_classification_model"
+
+    parameters = {
+        "top_k": 2,
+    }
+    inference_payload = json.dumps(
+        {
+            "inputs": [image_url],
+            "params": parameters,
+        }
+    )
+
+    with mlflow.start_run():
+        model_info = mlflow.transformers.log_model(
+            transformers_model=small_vision_model,
+            artifact_path=artifact_path,
+            signature=infer_signature(
+                image_url,
+                mlflow.transformers.generate_signature_output(small_vision_model, image_url),
+                params=parameters,
+            ),
+        )
+        model_uri = model_info.model_uri
+    transformers_loaded_model = mlflow.transformers.load_model(model_uri)
+    expected_predictions = transformers_loaded_model.predict(image_url)
+
+    response = pyfunc_serve_and_score_model(
+        model_uri,
+        data=inference_payload,
+        content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
+        extra_args=["--env-manager", "local"],
+    )
+
+    predictions = PredictionsResponse.from_json(response.content.decode("utf-8")).get_predictions()
+
+    assert (
+        list(predictions.to_dict("records")[0].values())
+        == expected_predictions[: parameters["top_k"]]
+    )
+
+
 def test_qa_pipeline_pyfunc_predict_with_kwargs(small_qa_pipeline):
     artifact_path = "qa_model"
     data = {
@@ -3935,9 +4088,7 @@ def test_basic_model_with_accelerate_homogeneous_mapping_works(tmp_path):
     mlflow.transformers.save_model(transformers_model=pipeline, path=str(tmp_path / "model"))
 
     loaded = mlflow.transformers.load_model(str(tmp_path / "model"))
-
     text = "Apples are delicious"
-
     assert loaded(text) == pipeline(text)
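For reference, once such a model is deployed (for example via mlflow models serve), the new test_vision_pipeline_pyfunc_predict_with_kwargs corresponds to a scoring request shaped like the sketch below; the endpoint URL/port and the local file path are assumptions, while the payload structure and the top_k parameter mirror the test:

    import base64
    import json

    import requests

    # Either a URL/file path or a base64-encoded image is accepted after this change
    with open("tests/datasets/cat.png", "rb") as f:
        encoded_image = base64.b64encode(f.read()).decode("utf-8")

    payload = json.dumps(
        {
            "inputs": [encoded_image],
            "params": {"top_k": 2},  # forwarded to the pipeline as inference kwargs
        }
    )
    response = requests.post(
        "http://127.0.0.1:5000/invocations",  # assumed local `mlflow models serve` endpoint
        data=payload,
        headers={"Content-Type": "application/json"},
    )
    print(response.json())  # e.g. [{"label": ..., "score": ...}, ...]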
