Commit

initpy and test_transformers_model_export

Signed-off-by: swathi <konakanchi.swathi@gmail.com>
KonakanchiSwathi committed Nov 29, 2023
1 parent 4b8bb73 commit b9b4b2f
Showing 2 changed files with 249 additions and 25 deletions.
121 changes: 104 additions & 17 deletions mlflow/transformers/__init__.py
@@ -1337,7 +1337,6 @@ def _should_add_pyfunc_to_model(pipeline) -> bool:
        "DocumentQuestionAnsweringPipeline",
        "ImageToTextPipeline",
        "VisualQuestionAnsweringPipeline",
-        "ImageClassificationPipeline",
"ImageSegmentationPipeline",
"DepthEstimationPipeline",
"ObjectDetectionPipeline",
@@ -1347,11 +1346,6 @@ def _should_add_pyfunc_to_model(pipeline) -> bool:
"ZeroShotAudioClassificationPipeline",
]

-    impermissible_attrs = {"image_processor"}
-
-    for attr in impermissible_attrs:
-        if getattr(pipeline, attr, None) is not None:
-            return False
for model_type in exclusion_model_types:
if hasattr(transformers, model_type):
if isinstance(pipeline.model, getattr(transformers, model_type)):
@@ -1419,7 +1413,13 @@ def _get_default_pipeline_signature(pipeline, example=None, model_config=None) -
return ModelSignature(
inputs=Schema([ColSpec("string")]), outputs=Schema([ColSpec("string")])
)
-    elif isinstance(pipeline, transformers.TextClassificationPipeline):
+    elif isinstance(
+        pipeline,
+        (
+            transformers.TextClassificationPipeline,
+            transformers.ImageClassificationPipeline,
+        ),
+    ):
return ModelSignature(
inputs=Schema([ColSpec("string")]),
outputs=Schema([ColSpec("string", name="label"), ColSpec("double", name="score")]),
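
Since image classification pipelines now share the text classification signature (string input; string "label" and double "score" outputs), a logged model can be scored through pyfunc directly. A minimal sketch, not part of the diff, assuming a hypothetical model URI:

import mlflow

pyfunc_model = mlflow.pyfunc.load_model("models:/my_image_classifier/1")  # hypothetical URI
predictions = pyfunc_model.predict(
    ["https://raw.githubusercontent.com/mlflow/mlflow/master/tests/datasets/cat.png"]
)
# Per the signature above, predictions is a pandas DataFrame with
# "label" (string) and "score" (double) columns.
print(predictions)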
@@ -1757,7 +1757,7 @@ def predict(self, data, params: Optional[Dict[str, Any]] = None):
            raise MlflowException(
                "Input data must be either a pandas.DataFrame, a string, bytes, List[str], "
                "List[Dict[str, str]], List[Dict[str, Union[str, List[str]]]], "
-                "or Dict[str, Union[str, List[str]]].",
+                "or Dict[str, Union[str, List[str]]]",
error_code=INVALID_PARAMETER_VALUE,
)
input_data = self._parse_raw_pipeline_input(input_data)
@@ -1799,6 +1799,9 @@ def _predict(self, data):
output_key = "token_str"
elif isinstance(self.pipeline, transformers.TextClassificationPipeline):
output_key = "label"
elif isinstance(self.pipeline, transformers.ImageClassificationPipeline):
data = self._convert_image_input(data)
output_key = "label"
elif isinstance(self.pipeline, transformers.ZeroShotClassificationPipeline):
output_key = "labels"
data = self._parse_json_encoded_list(data, "candidate_labels")
@@ -1877,7 +1880,11 @@ def _predict(self, data):
output = json.dumps(raw_output)
elif isinstance(
self.pipeline,
-            (transformers.AudioClassificationPipeline, transformers.TextClassificationPipeline),
+            (
+                transformers.AudioClassificationPipeline,
+                transformers.TextClassificationPipeline,
+                transformers.ImageClassificationPipeline,
+            ),
):
return pd.DataFrame(raw_output)
else:
@@ -2564,6 +2571,60 @@ def _convert_cast_lists_from_np_back_to_list(data):
parsed_data.append(entry)
return parsed_data

    def _convert_image_input(self, data):
        """
        Conversion utility for decoding the base64 encoded bytes data of a raw image file when
        parsed through model serving, if applicable. Direct usage of the pyfunc implementation
        outside of model serving will treat this utility as a no-op.

        For reference, the expected encoding for input to Model Serving will be:

        import base64
        import json
        import requests

        response = requests.get("https://www.my.images/a/image/file.jpg")
        encoded_image = base64.b64encode(response.content).decode("utf-8")

        inference_data = json.dumps({"inputs": [encoded_image]})

        or

        inference_df = pd.DataFrame(
            pd.Series([encoded_image], name="image_file")
        )
        split_dict = {"dataframe_split": inference_df.to_dict(orient="split")}
        split_json = json.dumps(split_dict)

        or

        records_dict = {"dataframe_records": inference_df.to_dict(orient="records")}
        records_json = json.dumps(records_dict)

        This utility converts such JSON encoded, base64 encoded text back into bytes for
        input into the image pipelines for inference.
        """

        def is_base64_image(s):
            try:
                return base64.b64encode(base64.b64decode(s)).decode("utf-8") == s
            except binascii.Error:
                return False

        if isinstance(data, list) and all(isinstance(element, dict) for element in data):
            lst_data = []
            for item in data:
                data_ele = next(iter(item.values()))
                if isinstance(data_ele, str):
                    # A base64 encoded image arrives as a string; anything else must be
                    # a valid file path or uri.
                    if not is_base64_image(data_ele):
                        self._validate_str_input_uri_or_file(data_ele)
                    lst_data.append(data_ele)
            return lst_data
        elif isinstance(data, str):
            if not is_base64_image(data):
                self._validate_str_input_uri_or_file(data)
        return data
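
For an end-to-end view of the encodings the docstring above describes, a minimal client-side sketch follows; the endpoint URL is a placeholder and assumes a model served locally (for example via mlflow models serve):

import base64
import json

import pandas as pd
import requests

image_bytes = requests.get(
    "https://raw.githubusercontent.com/mlflow/mlflow/master/tests/datasets/cat.png"
).content
encoded_image = base64.b64encode(image_bytes).decode("utf-8")

inference_df = pd.DataFrame(pd.Series([encoded_image], name="image_file"))
payloads = [
    json.dumps({"inputs": [encoded_image]}),
    json.dumps({"dataframe_split": inference_df.to_dict(orient="split")}),
    json.dumps({"dataframe_records": inference_df.to_dict(orient="records")}),
]

for payload in payloads:
    # Placeholder endpoint; each payload shape is decoded back to image bytes
    # by _convert_image_input on the server side.
    requests.post(
        "http://127.0.0.1:5000/invocations",
        data=payload,
        headers={"Content-Type": "application/json"},
    )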

def _convert_audio_input(self, data):
"""
Conversion utility for decoding the base64 encoded bytes data of a raw soundfile when
@@ -2652,7 +2713,8 @@ def decode_audio(encoded):
@staticmethod
def _validate_str_input_uri_or_file(input_str):
"""
-        Validation of blob references to audio files, if a string is input to the ``predict``
+        Validation of blob references for audio or image transformers pipelines; if a string
+        is input to the ``predict``
method, perform validation of the string contents by checking for a valid uri or
filesystem reference instead of surfacing the cryptic stack trace that is otherwise raised
for an invalid uri input.
@@ -2665,14 +2727,39 @@ def is_uri(s):
except ValueError:
return False

-        valid_uri = os.path.isfile(input_str) or is_uri(input_str)
+        def validate_nested_list(key, lst):
+            for item in lst:
+                if isinstance(item, list):
+                    validate_nested_list(key, item)
+                else:
+                    validate_single_input(key, item)

-        if not valid_uri:
-            raise MlflowException(
-                "An invalid string input was provided. String inputs to "
-                "audio files must be either a file location or a uri.",
-                error_code=BAD_REQUEST,
-            )
+        def validate_input(key, value):
+            valid_uri = os.path.isfile(value) or is_uri(value)
+
+            if not valid_uri:
+                raise MlflowException(
+                    "An invalid string input was provided. String inputs to "
+                    "audio or image files must be either a file location or a uri.",
+                    error_code=BAD_REQUEST,
+                )
+
+        def validate_single_input(key, value):
+            if isinstance(value, list):
+                validate_nested_list(key, value)
+            else:
+                validate_input(key, value)
+
+        if isinstance(input_str, dict):
+            for key, value in input_str.items():
+                validate_single_input(key, value)
+        else:
+            validate_single_input(None, input_str)
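
For intuition, the checks above reduce to os.path.isfile plus a uri parse on every leaf string; a standalone sketch with hypothetical inputs (the is_uri here is a simplified stand-in for the helper defined earlier):

import os
from urllib.parse import urlparse

def is_uri(s):
    # Simplified stand-in: any string with a url scheme counts as a uri.
    try:
        return urlparse(s).scheme != ""
    except ValueError:
        return False

for candidate in [
    "https://example.com/images/cat.png",  # uri -> valid
    "/tmp/cat.png",                        # valid only if the file exists
    "definitely not a file or uri",        # invalid -> MlflowException above
]:
    print(candidate, "->", os.path.isfile(candidate) or is_uri(candidate))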


@experimental
153 changes: 145 additions & 8 deletions tests/transformers/test_transformers_model_export.py
@@ -80,6 +80,7 @@
# runners#supported-runners-and-hardware-resources for instance specs.
RUNNING_IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true"
GITHUB_ACTIONS_SKIP_REASON = "Test consumes too much memory"
image_url = "https://raw.githubusercontent.com/mlflow/mlflow/master/tests/datasets/cat.png"

# Test that can only be run locally:
# - Summarization pipeline tests
@@ -479,7 +480,7 @@ def test_instance_extraction(small_qa_pipeline):
        ("small_qa_pipeline", True),
        ("small_seq2seq_pipeline", True),
        ("small_multi_modal_pipeline", False),
-        ("small_vision_model", False),
+        ("small_vision_model", True),
],
)
def test_pipeline_eligibility_for_pyfunc_registration(model, result, request):
@@ -625,6 +626,31 @@ def test_vision_model_save_pipeline_with_defaults(small_vision_model, model_path
assert flavor_config["source_model_name"] == "google/mobilenet_v2_1.0_224"


def test_vision_model_save_model_for_task_and_card_inference(small_vision_model, model_path):
mlflow.transformers.save_model(transformers_model=small_vision_model, path=model_path)
# validate inferred pip requirements
with model_path.joinpath("requirements.txt").open() as file:
requirements = file.read()
reqs = {req.split("==")[0] for req in requirements.split("\n")}
expected_requirements = {"torch", "torchvision", "transformers"}
assert reqs.intersection(expected_requirements) == expected_requirements
# validate inferred model card data
card_data = yaml.safe_load(model_path.joinpath("model_card_data.yaml").read_bytes())
assert card_data["tags"] == ["vision", "image-classification"]
# Validate inferred model card text
with model_path.joinpath("model_card.md").open(encoding="utf-8") as file:
card_text = file.read()
assert len(card_text) > 0

# Validate the MLModel file
mlmodel = yaml.safe_load(model_path.joinpath("MLmodel").read_bytes())
flavor_config = mlmodel["flavors"]["transformers"]
assert flavor_config["instance_type"] == "ImageClassificationPipeline"
assert flavor_config["pipeline_model_type"] == "MobileNetV2ForImageClassification"
assert flavor_config["task"] == "image-classification"
assert flavor_config["source_model_name"] == "google/mobilenet_v2_1.0_224"


def test_qa_model_save_model_for_task_and_card_inference(small_seq2seq_pipeline, model_path):
mlflow.transformers.save_model(
transformers_model={
@@ -971,11 +997,6 @@ def test_transformers_log_model_with_no_registered_model_name(small_vision_model
conda_env=str(conda_env),
)
mlflow.tracking._model_registry.fluent._register_model.assert_not_called()
-    model_uri = f"runs:/{mlflow.active_run().info.run_id}/{artifact_path}"
-    model_path = pathlib.Path(_download_artifact_from_uri(artifact_uri=model_uri))
-    model_config = Model.load(str(model_path.joinpath("MLmodel")))
-    # Vision models can't be loaded as pyfunc currently.
-    assert pyfunc.FLAVOR_NAME not in model_config.flavors


def test_transformers_save_persists_requirements_in_mlflow_directory(
@@ -1341,6 +1362,49 @@ def test_qa_pipeline_pyfunc_load_and_infer(small_qa_pipeline, model_path, infere
assert all(isinstance(element, str) for element in inference)


def read_image(filename):
image_path = os.path.join(pathlib.Path(__file__).parent.parent, "datasets", filename)
with open(image_path, "rb") as f:
return f.read()


def is_base64_image(s):
try:
return base64.b64encode(base64.b64decode(s)).decode("utf-8") == s
except Exception:
return False


@pytest.mark.parametrize(
"inference_payload",
[
image_url,
os.path.join(pathlib.Path(__file__).parent.parent, "datasets", "cat.png"),
"base64",
],
)
def test_vision_pipeline_pyfunc_load_and_infer(small_vision_model, model_path, inference_payload):
    if inference_payload == "base64":
        # Only run the base64 payload case on transformers versions that support
        # base64-encoded image inputs.
        if not Version("4.28") < Version(transformers.__version__) < Version("4.33"):
            return
        inference_payload = base64.b64encode(read_image("cat_image.jpg")).decode("utf-8")
signature = infer_signature(
inference_payload,
mlflow.transformers.generate_signature_output(small_vision_model, inference_payload),
)
mlflow.transformers.save_model(
transformers_model=small_vision_model,
path=model_path,
signature=signature,
)
pyfunc_loaded = mlflow.pyfunc.load_model(model_path)
predictions = pyfunc_loaded.predict(inference_payload)
assert len(predictions) != 0


@pytest.mark.skipif(RUNNING_IN_GITHUB_ACTIONS, reason=GITHUB_ACTIONS_SKIP_REASON)
@pytest.mark.parametrize(
("data", "result"),
[
@@ -2062,6 +2126,43 @@ def test_qa_pipeline_pyfunc_predict(small_qa_pipeline):
assert values.to_dict(orient="records") == [{0: "Run"}]


@pytest.mark.parametrize(
"inference_payload",
[
[os.path.join(pathlib.Path(__file__).parent.parent, "datasets", "cat.png")],
[image_url, image_url],
"base64",
],
)
def test_vision_pipeline_pyfunc_predict(small_vision_model, inference_payload):
    if inference_payload == "base64":
        # Only run the base64 payload case on transformers versions that support
        # base64-encoded image inputs.
        if not Version("4.28") < Version(transformers.__version__) < Version("4.33"):
            return
        inference_payload = [
            base64.b64encode(read_image("cat_image.jpg")).decode("utf-8"),
        ]
artifact_path = "image_classification_model"

# Log the image classification model
with mlflow.start_run():
mlflow.transformers.log_model(
transformers_model=small_vision_model,
artifact_path=artifact_path,
)
model_uri = mlflow.get_artifact_uri(artifact_path)
inference_payload = json.dumps({"inputs": inference_payload})
response = pyfunc_serve_and_score_model(
model_uri,
data=inference_payload,
content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
extra_args=["--env-manager", "local"],
)

predictions = PredictionsResponse.from_json(response.content.decode("utf-8")).get_predictions()

assert len(predictions) != 0


def test_classifier_pipeline_pyfunc_predict(text_classification_pipeline):
artifact_path = "text_classifier_model"
with mlflow.start_run():
@@ -3464,6 +3565,44 @@ def test_save_model_card_with_non_utf_characters(tmp_path, model_name):
assert data == card_data.data.to_dict()


def test_vision_pipeline_pyfunc_predict_with_kwargs(small_vision_model):
artifact_path = "image_classification_model"

image_file_paths = image_url
parameters = {
"top_k": 2,
}
inference_payload = json.dumps(
{
"inputs": [image_file_paths],
"params": parameters,
}
)

with mlflow.start_run():
mlflow.transformers.log_model(
transformers_model=small_vision_model,
artifact_path=artifact_path,
signature=infer_signature(
image_file_paths,
mlflow.transformers.generate_signature_output(small_vision_model, image_file_paths),
params=parameters,
),
)
model_uri = mlflow.get_artifact_uri(artifact_path)

response = pyfunc_serve_and_score_model(
model_uri,
data=inference_payload,
content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
extra_args=["--env-manager", "local"],
)

predictions = PredictionsResponse.from_json(response.content.decode("utf-8")).get_predictions()
assert len(predictions) != 0
assert len(predictions.iloc[0]) == parameters["top_k"]
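
The same top_k routing works outside of serving through the pyfunc API; a short sketch with a placeholder run id, reusing the module-level image_url:

import mlflow

loaded = mlflow.pyfunc.load_model("runs:/<run_id>/image_classification_model")  # placeholder
# params are forwarded to the underlying transformers pipeline call.
predictions = loaded.predict(image_url, params={"top_k": 2})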


def test_qa_pipeline_pyfunc_predict_with_kwargs(small_qa_pipeline):
artifact_path = "qa_model"
data = {
@@ -3853,9 +3992,7 @@ def test_basic_model_with_accelerate_homogeneous_mapping_works(tmp_path):
mlflow.transformers.save_model(transformers_model=pipeline, path=str(tmp_path / "model"))

loaded = mlflow.transformers.load_model(str(tmp_path / "model"))

text = "Apples are delicious"

assert loaded(text) == pipeline(text)

