initpy and test_transformers_model_export #10538

Merged
Changes from 37 commits
Commits (44)
cfebd44
initpy and test_transformers_model_export
KonakanchiSwathi Nov 29, 2023
1b899c2
Update test_transformers_model_export.py
KonakanchiSwathi Dec 6, 2023
82d011d
Signed-off-by:swathi <konakanchi.swathi@gmail.com>
KonakanchiSwathi Dec 6, 2023
97ae2f4
Signed-off-by:swathi <konakanchi.swathi@gmail.com>
KonakanchiSwathi Dec 6, 2023
515c9f2
Merge branch 'AddImageclassification_newbranch' of https://github.com…
KonakanchiSwathi Dec 6, 2023
fc6ce63
Merge branch 'AddImageclassification_newbranch' of https://github.com…
KonakanchiSwathi Dec 6, 2023
155fd99
Merge branch 'AddImageclassification_newbranch' of https://github.com…
KonakanchiSwathi Dec 6, 2023
b8c51c2
Signed-off-by: swathi <konakanchi.swathi@gmail.com>
KonakanchiSwathi Dec 6, 2023
9710901
Signed-off-by:swathi <konakanchi.swathi@gmail.com>
KonakanchiSwathi Dec 6, 2023
167d836
Merge branch 'AddImageclassification_newbranch' of https://github.com…
KonakanchiSwathi Dec 6, 2023
37a00b8
fix: abstract classes use metaclass=ABCMeta (#10509)
keenranger Nov 28, 2023
63de361
Use Azure OpenAI in LLM RAG eval notebook (#10527)
prithvikannan Nov 28, 2023
a3d807f
Deprecate model registry stages in docs (#10480)
jerrylian-db Nov 28, 2023
b8b55d7
Fix for keras 3.0 (#10485)
serena-ruan Nov 29, 2023
d745eee
Refactor langchain (#10532)
serena-ruan Nov 29, 2023
e486eb9
Threaded proxy multipart upload (#10534)
gabrielfu Nov 29, 2023
67db66b
initpy and test_transformers_model_export
KonakanchiSwathi Nov 29, 2023
a104e85
Signed-off-by: swathi <konakanchi.swathi@gmail.com>
KonakanchiSwathi Dec 6, 2023
0e06964
Signed-off-by:swathi <konakanchi.swathi@gmail.com>
KonakanchiSwathi Dec 6, 2023
a546040
Update test_transformers_model_export.py
KonakanchiSwathi Dec 6, 2023
3b66134
Signed-off-by:swathi <konakanchi.swathi@gmail.com>
KonakanchiSwathi Dec 6, 2023
fc6f28b
Signed-off-by:swathi <konakanchi.swathi@gmail.com>
KonakanchiSwathi Dec 6, 2023
fdcefea
Merge branch 'AddImageclassification_newbranch' of https://github.com…
KonakanchiSwathi Dec 6, 2023
37c841b
Signed-off-by:swathi <konakanchi.swathi@gmail.com>
KonakanchiSwathi Dec 7, 2023
4c95261
Merge branch 'master' of https://github.com/KonakanchiSwathi/mlflow i…
MadhuM02 Dec 8, 2023
95ed7b7
addr commnets
Madhu0205 Dec 8, 2023
603a4d4
Merge branch 'mlflow:master' into AddImageclassification_newbranch
KonakanchiSwathi Dec 8, 2023
f3bbe7c
Update __init__.py
KonakanchiSwathi Dec 11, 2023
abd8c9f
Signed-off-by: swathi <konakanchi.swathi@gmail.com>
KonakanchiSwathi Dec 11, 2023
98cff65
Signed-off-by: swathi <konakanchi.swathi@gmail.com>
KonakanchiSwathi Dec 11, 2023
2b84af6
Signed-off-by: swathi <konakanchi.swathi@gmail.com>
KonakanchiSwathi Dec 11, 2023
5adc77d
fix uts
Madhu0205 Dec 12, 2023
d0db696
Merge branch 'mlflow:master' into AddImageclassification_newbranch
KonakanchiSwathi Dec 12, 2023
e5e89bd
update
Madhu0205 Dec 15, 2023
7c555d3
Merge branch 'AddImageclassification_newbranch' of https://github.com…
Madhu0205 Dec 15, 2023
ab9fee7
Signed-off-by: swathi <konakanchi.swathi@gmail.com>
KonakanchiSwathi Dec 15, 2023
4dbbaf8
Signed-off-by: swathi <konakanchi.swathi@gmail.com>
KonakanchiSwathi Dec 15, 2023
a309e45
Update tests/transformers/test_transformers_model_export.py
KonakanchiSwathi Dec 21, 2023
be250a6
Merge branch 'master' into AddImageclassification_newbranch
KonakanchiSwathi Dec 21, 2023
a44da12
Added model_uri=model_info.model_uri Signed-off-by: swathi <k…
KonakanchiSwathi Dec 21, 2023
0b5ab75
Signed-off-by: swathi <konakanchi.swathi@gmail.com>
KonakanchiSwathi Dec 21, 2023
e73a548
Signed-off-by: swathi <konakanchi.swathi@gmail.com>
KonakanchiSwathi Dec 21, 2023
2c54870
Update mlflow/transformers/__init__.py
KonakanchiSwathi Dec 21, 2023
3973ffb
Signed-off-by: "v-swathikon konakanchi.swathi@gmail.com"
KonakanchiSwathi Dec 21, 2023
87 changes: 77 additions & 10 deletions mlflow/transformers/__init__.py
@@ -1344,7 +1344,6 @@ def _should_add_pyfunc_to_model(pipeline) -> bool:
"DocumentQuestionAnsweringPipeline",
"ImageToTextPipeline",
"VisualQuestionAnsweringPipeline",
"ImageClassificationPipeline",
"ImageSegmentationPipeline",
"DepthEstimationPipeline",
"ObjectDetectionPipeline",
@@ -1354,11 +1353,6 @@ def _should_add_pyfunc_to_model(pipeline) -> bool:
"ZeroShotAudioClassificationPipeline",
]

-    impermissible_attrs = {"image_processor"}
-
-    for attr in impermissible_attrs:
-        if getattr(pipeline, attr, None) is not None:
-            return False
for model_type in exclusion_model_types:
if hasattr(transformers, model_type):
if isinstance(pipeline.model, getattr(transformers, model_type)):
@@ -1426,7 +1420,13 @@ def _get_default_pipeline_signature(pipeline, example=None, model_config=None) -
return ModelSignature(
inputs=Schema([ColSpec("string")]), outputs=Schema([ColSpec("string")])
)
-    elif isinstance(pipeline, transformers.TextClassificationPipeline):
+    elif isinstance(
+        pipeline,
+        (
+            transformers.TextClassificationPipeline,
+            transformers.ImageClassificationPipeline,
+        ),
+    ):
return ModelSignature(
inputs=Schema([ColSpec("string")]),
outputs=Schema([ColSpec("string", name="label"), ColSpec("double", name="score")]),
@@ -1816,6 +1816,9 @@ def _predict(self, data):
output_key = "token_str"
elif isinstance(self.pipeline, transformers.TextClassificationPipeline):
output_key = "label"
elif isinstance(self.pipeline, transformers.ImageClassificationPipeline):
data = self._convert_image_input(data)
output_key = "label"
elif isinstance(self.pipeline, transformers.ZeroShotClassificationPipeline):
output_key = "labels"
data = self._parse_json_encoded_list(data, "candidate_labels")
@@ -1894,7 +1897,11 @@ def _predict(self, data):
output = json.dumps(raw_output)
elif isinstance(
self.pipeline,
-    (transformers.AudioClassificationPipeline, transformers.TextClassificationPipeline),
+    (
+        transformers.AudioClassificationPipeline,
+        transformers.TextClassificationPipeline,
+        transformers.ImageClassificationPipeline,
+    ),
):
return pd.DataFrame(raw_output)
else:
@@ -2581,6 +2588,65 @@ def _convert_cast_lists_from_np_back_to_list(data):
parsed_data.append(entry)
return parsed_data

@staticmethod
def is_base64_image(image):
"""Check whether input image is a base64 encoded"""

try:
return base64.b64encode(base64.b64decode(image)).decode("utf-8") == image
Collaborator: Is decode("utf-8") necessary? If not, could we reuse is_base64?

Contributor: without the decode part, they are not matching

Collaborator: I think if you don't decode the image when you read it, then it matches:

image = base64.b64encode(read_image("cat_image.jpg"))
base64.b64encode(base64.b64decode(image)) == image

>>> with open("/Users/serena.ruan/Documents/repos/mlflow/tests/datasets/cat.png", "rb") as f:
...     image = f.read()
...
>>> import base64
>>> image = base64.b64encode(image)
>>> base64.b64encode(base64.b64decode(image)) == image

Contributor:

if isinstance(image, str):
    if image.startswith("http://") or image.startswith("https://"):
        # We need to actually check for a real protocol, otherwise it's impossible to use a local file
        # like http_huggingface_co.png
        image = PIL.Image.open(requests.get(image, stream=True, timeout=timeout).raw)
    elif os.path.isfile(image):
        image = PIL.Image.open(image)
    else:
        if image.startswith("data:image/"):
            image = image.split(",")[1]

        # Try to load as base64
        try:
            b64 = base64.b64decode(image, validate=True)
            image = PIL.Image.open(BytesIO(b64))
        except Exception as e:
            raise ValueError(
                f"Incorrect image source. Must be a valid URL starting with `http://` or `https://`, a valid path to an image file, or a base64 encoded string. Got {image}. Failed with {e}"
            )

transformers accepts the base64 image as a string; it does not accept the raw b64-encoded bytes.

except binascii.Error:
return False
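
A standalone sketch of why the decode("utf-8") in the check above matters (assumes a local cat.png; base64.b64encode returns bytes, so a str payload only round-trips if the re-encoded bytes are decoded):

import base64

with open("cat.png", "rb") as f:  # assumed local test image
    raw = f.read()

# Model serving delivers the payload as a str, not bytes
encoded_str = base64.b64encode(raw).decode("utf-8")

# The round-trip equality only holds after decoding back to str
assert base64.b64encode(base64.b64decode(encoded_str)).decode("utf-8") == encoded_str
# Without .decode(), the left side is bytes and never equals the str
assert base64.b64encode(base64.b64decode(encoded_str)) != encoded_str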

def _convert_image_input(self, input_data):
"""
Conversion utility for decoding the base64 encoded bytes data of a raw image file when
parsed through model serving, if applicable. Direct usage of the pyfunc implementation
outside of model serving will treat this utility as a noop.

For reference, the expected encoding for input to Model Serving will be:

import requests
import base64

response = requests.get("https://www.my.images/a/sound/file.jpg")
encoded_image = base64.b64encode(response.content).decode("utf-8")

inference_data = json.dumps({"inputs": [encoded_image]})

or

inference_df = pd.DataFrame(
pd.Series([encoded_image], name="image_file")
)
split_dict = {"dataframe_split": inference_df.to_dict(orient="split")}
split_json = json.dumps(split_dict)

or

records_dict = {"dataframe_records": inference_df.to_dict(orient="records")}
records_json = json.dumps(records_dict)

This utility will convert this JSON encoded, base64 encoded text back into bytes for
input into the Image pipelines for inference.
"""

def process_input_element(input_element):
input_value = next(iter(input_element.values()))
if isinstance(input_value, str) and not self.is_base64_image(input_value):
self._validate_str_input_uri_or_file(input_value)
return input_value

if isinstance(input_data, list) and all(
isinstance(element, dict) for element in input_data
):
# Use a list comprehension for readability and to avoid
# declaring an empty collection to accumulate results
return [process_input_element(element) for element in input_data]
elif isinstance(input_data, str) and not self.is_base64_image(input_data):
self._validate_str_input_uri_or_file(input_data)

return input_data
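
Putting the docstring's payload format together with a live endpoint, a hedged end-to-end scoring sketch (assumes "mlflow models serve -m <model_uri> -p 5000" is running and a local cat.png; names are illustrative):

import base64
import json

import requests

with open("cat.png", "rb") as f:  # assumed local image file
    encoded_image = base64.b64encode(f.read()).decode("utf-8")

response = requests.post(
    "http://127.0.0.1:5000/invocations",  # default local serving route
    headers={"Content-Type": "application/json"},
    data=json.dumps({"inputs": [encoded_image]}),
)
# For an image-classification pipeline this returns label/score records
print(response.json())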

def _convert_audio_input(self, data):
"""
Conversion utility for decoding the base64 encoded bytes data of a raw soundfile when
@@ -2669,7 +2735,8 @@ def decode_audio(encoded):
@staticmethod
def _validate_str_input_uri_or_file(input_str):
Collaborator: Could we keep this method as it is and move your check on the data into a new function? The current name suggests it validates a single input string.

Contributor: fixed, thanks

"""
-    Validation of blob references to audio files, if a string is input to the ``predict``
+    Validation of blob references to either audio or image files,
+    if a string is input to the ``predict``
method, perform validation of the string contents by checking for a valid uri or
filesystem reference instead of surfacing the cryptic stack trace that is otherwise raised
for an invalid uri input.
@@ -2687,7 +2754,7 @@ def is_uri(s):
if not valid_uri:
raise MlflowException(
"An invalid string input was provided. String inputs to "
"audio files must be either a file location or a uri.",
"audio or image files must be either a file location or a uri.",
KonakanchiSwathi marked this conversation as resolved.
error_code=BAD_REQUEST,
)
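
The hunk above references an inner is_uri helper without showing its body; a minimal sketch of the kind of check it performs (an assumption inferred from the error message, not the PR's exact code):

from urllib.parse import urlparse

def is_uri(s: str) -> bool:
    # Treat a string as a uri when it parses with a scheme,
    # e.g. "https://host/cat.png" or "file:///tmp/cat.png"
    try:
        return urlparse(s).scheme != ""
    except ValueError:
        return False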

175 changes: 164 additions & 11 deletions tests/transformers/test_transformers_model_export.py
@@ -50,6 +50,7 @@
_record_pipeline_components,
_should_add_pyfunc_to_model,
_TransformersModel,
_TransformersWrapper,
_validate_transformers_task_type,
_write_card_data,
get_default_conda_env,
@@ -80,7 +81,8 @@
# runners#supported-runners-and-hardware-resources for instance specs.
RUNNING_IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true"
GITHUB_ACTIONS_SKIP_REASON = "Test consumes too much memory"

image_url = "https://raw.githubusercontent.com/mlflow/mlflow/master/tests/datasets/cat.png"
image_file_path = pathlib.Path(pathlib.Path(__file__).parent.parent, "datasets", "cat.png")
# Test that can only be run locally:
# - Summarization pipeline tests
# - TextClassifier pipeline tests
@@ -479,7 +481,7 @@ def test_instance_extraction(small_qa_pipeline):
("small_qa_pipeline", True),
("small_seq2seq_pipeline", True),
("small_multi_modal_pipeline", False),
("small_vision_model", False),
("small_vision_model", True),
],
)
def test_pipeline_eligibility_for_pyfunc_registration(model, result, request):
@@ -599,8 +601,7 @@ def test_model_card_acquisition_vision_model(small_vision_model):
def test_vision_model_save_pipeline_with_defaults(small_vision_model, model_path):
mlflow.transformers.save_model(transformers_model=small_vision_model, path=model_path)
# validate inferred pip requirements
-    with model_path.joinpath("requirements.txt").open() as file:
-        requirements = file.read()
+    requirements = model_path.joinpath("requirements.txt").read_text()
reqs = {req.split("==")[0] for req in requirements.split("\n")}
expected_requirements = {"torch", "torchvision", "transformers"}
assert reqs.intersection(expected_requirements) == expected_requirements
@@ -625,6 +626,29 @@ def test_vision_model_save_pipeline_with_defaults(small_vision_model, model_path
assert flavor_config["source_model_name"] == "google/mobilenet_v2_1.0_224"


def test_vision_model_save_model_for_task_and_card_inference(small_vision_model, model_path):
mlflow.transformers.save_model(transformers_model=small_vision_model, path=model_path)
# validate inferred pip requirements
requirements = model_path.joinpath("requirements.txt").read_text()
reqs = {req.split("==")[0] for req in requirements.split("\n")}
expected_requirements = {"torch", "torchvision", "transformers"}
assert reqs.intersection(expected_requirements) == expected_requirements
# validate inferred model card data
card_data = yaml.safe_load(model_path.joinpath("model_card_data.yaml").read_bytes())
assert card_data["tags"] == ["vision", "image-classification"]
# Validate inferred model card text
card_text = model_path.joinpath("model_card.md").read_text(encoding="utf-8")
assert len(card_text) > 0

# Validate the MLModel file
mlmodel = yaml.safe_load(model_path.joinpath("MLmodel").read_bytes())
flavor_config = mlmodel["flavors"]["transformers"]
assert flavor_config["instance_type"] == "ImageClassificationPipeline"
assert flavor_config["pipeline_model_type"] == "MobileNetV2ForImageClassification"
assert flavor_config["task"] == "image-classification"
assert flavor_config["source_model_name"] == "google/mobilenet_v2_1.0_224"


def test_qa_model_save_model_for_task_and_card_inference(small_seq2seq_pipeline, model_path):
mlflow.transformers.save_model(
transformers_model={
@@ -971,11 +995,6 @@ def test_transformers_log_model_with_no_registered_model_name(small_vision_model
conda_env=str(conda_env),
)
mlflow.tracking._model_registry.fluent._register_model.assert_not_called()
model_uri = f"runs:/{mlflow.active_run().info.run_id}/{artifact_path}"
model_path = pathlib.Path(_download_artifact_from_uri(artifact_uri=model_uri))
model_config = Model.load(str(model_path.joinpath("MLmodel")))
# Vision models can't be loaded as pyfunc currently.
assert pyfunc.FLAVOR_NAME not in model_config.flavors


def test_transformers_save_persists_requirements_in_mlflow_directory(
Expand Down Expand Up @@ -1341,6 +1360,40 @@ def test_qa_pipeline_pyfunc_load_and_infer(small_qa_pipeline, model_path, infere
assert all(isinstance(element, str) for element in inference)


@pytest.mark.parametrize(
"inference_payload",
[
image_url,
str(image_file_path),
pytest.param(
"base64",
marks=pytest.mark.skipif(
Version(transformers.__version__) < Version("4.33"),
reason="base64 feature not present",
),
),
],
)
def test_vision_pipeline_pyfunc_load_and_infer(small_vision_model, model_path, inference_payload):
if inference_payload == "base64":
inference_payload = base64.b64encode(image_file_path.read_bytes()).decode("utf-8")
signature = infer_signature(
inference_payload,
mlflow.transformers.generate_signature_output(small_vision_model, inference_payload),
)
mlflow.transformers.save_model(
transformers_model=small_vision_model,
path=model_path,
signature=signature,
)
pyfunc_loaded = mlflow.pyfunc.load_model(model_path)
predictions = pyfunc_loaded.predict(inference_payload)

transformers_loaded_model = mlflow.transformers.load_model(model_path)
expected_predictions = transformers_loaded_model.predict(inference_payload)
assert list(predictions.to_dict("records")[0].values()) == expected_predictions


@pytest.mark.parametrize(
("data", "result"),
[
@@ -2062,6 +2115,65 @@ def test_qa_pipeline_pyfunc_predict(small_qa_pipeline):
assert values.to_dict(orient="records") == [{0: "Run"}]


@pytest.mark.parametrize(
("input_image", "result"),
[
(str(image_file_path), False),
(image_url, False),
("base64", True),
("random string", False),
],
)
def test_vision_is_base64_image(input_image, result):
if input_image == "base64":
input_image = base64.b64encode(image_file_path.read_bytes()).decode("utf-8")
assert _TransformersWrapper.is_base64_image(input_image) == result


@pytest.mark.parametrize(
"inference_payload",
[
[str(image_file_path)],
[image_url],
pytest.param(
"base64",
marks=pytest.mark.skipif(
Version(transformers.__version__) < Version("4.33"),
reason="base64 feature not present",
),
),
],
)
def test_vision_pipeline_pyfunc_predict(small_vision_model, inference_payload):
if inference_payload == "base64":
Member: Let's not embed complex logic like this that mutates the parameter based on string matching. Just create another test explicitly for this condition.

Contributor: Moved this logic inside because passing the payload as an input parameter prints a lengthy base64 string in the test suite, and a separate base64 test would mean duplicating three tests.

inference_payload = [
base64.b64encode(image_file_path.read_bytes()).decode("utf-8"),
]
artifact_path = "image_classification_model"

# Log the image classification model
with mlflow.start_run():
mlflow.transformers.log_model(
transformers_model=small_vision_model,
artifact_path=artifact_path,
)
model_uri = mlflow.get_artifact_uri(artifact_path)
pyfunc_inference_payload = json.dumps({"inputs": inference_payload})
response = pyfunc_serve_and_score_model(
model_uri,
data=pyfunc_inference_payload,
content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
extra_args=["--env-manager", "local"],
)
KonakanchiSwathi marked this conversation as resolved.

predictions = PredictionsResponse.from_json(response.content.decode("utf-8")).get_predictions()

transformers_loaded_model = mlflow.transformers.load_model(model_uri)
expected_predictions = transformers_loaded_model.predict(inference_payload)

assert [list(pred.values()) for pred in predictions.to_dict("records")] == expected_predictions
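
A hedged alternative to the string-matching pattern debated above: pytest.param accepts an id, which keeps a lengthy base64 payload out of the test node names without duplicating the test (the helper and test names here are illustrative, not the PR's code):

import base64
import pytest

def _base64_payload():
    # Hypothetical helper; encodes the shared test image lazily
    return [base64.b64encode(image_file_path.read_bytes()).decode("utf-8")]

@pytest.mark.parametrize(
    "payload_factory",
    [
        pytest.param(lambda: [str(image_file_path)], id="file"),
        pytest.param(lambda: [image_url], id="url"),
        pytest.param(_base64_payload, id="base64"),
    ],
)
def test_vision_payloads_sketch(payload_factory):
    inference_payload = payload_factory()
    ...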


def test_classifier_pipeline_pyfunc_predict(text_classification_pipeline):
artifact_path = "text_classifier_model"
with mlflow.start_run():
@@ -3486,6 +3598,49 @@ def test_save_model_card_with_non_utf_characters(tmp_path, model_name):
assert data == card_data.data.to_dict()


def test_vision_pipeline_pyfunc_predict_with_kwargs(small_vision_model):
artifact_path = "image_classification_model"

parameters = {
"top_k": 2,
}
inference_payload = json.dumps(
{
"inputs": [image_url],
"params": parameters,
}
)

with mlflow.start_run():
mlflow.transformers.log_model(
transformers_model=small_vision_model,
artifact_path=artifact_path,
signature=infer_signature(
image_url,
mlflow.transformers.generate_signature_output(small_vision_model, image_url),
params=parameters,
),
)
model_uri = mlflow.get_artifact_uri(artifact_path)
Collaborator: Similar here, let's use model_info.model_uri

transformers_loaded_model = mlflow.transformers.load_model(model_uri)
expected_predictions = transformers_loaded_model.predict(image_url)
Collaborator: You could use transformers_loaded_model.predict(image_url, params=parameters) and avoid limiting the result on line 3641

Contributor: transformers_loaded_model.predict(image_url, params=parameters) -> this is not supported.

Collaborator: Ah yes, you could use mlflow.pyfunc.load_model(model_uri) then predict with parameters

Contributor: Here I compared mlflow.pyfunc.load_model(model_uri).predict(data, params) == mlflow.transformers.load_model(model_uri).predict(data)[:top_k]


response = pyfunc_serve_and_score_model(
model_uri,
data=inference_payload,
content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
extra_args=["--env-manager", "local"],
)

predictions = PredictionsResponse.from_json(response.content.decode("utf-8")).get_predictions()

assert (
list(predictions.to_dict("records")[0].values())
== expected_predictions[: parameters["top_k"]]
)
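
A condensed sketch of the comparison the thread above converges on (assumes the model logged in this test and top_k=2): pyfunc applies params at predict time, while the raw transformers flavor returns the pipeline default, so the pyfunc output should equal the raw output truncated to top_k.

pyfunc_model = mlflow.pyfunc.load_model(model_uri)
pyfunc_preds = pyfunc_model.predict(image_url, params={"top_k": 2})

raw_preds = mlflow.transformers.load_model(model_uri).predict(image_url)

assert list(pyfunc_preds.to_dict("records")[0].values()) == raw_preds[:2]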


def test_qa_pipeline_pyfunc_predict_with_kwargs(small_qa_pipeline):
artifact_path = "qa_model"
data = {
@@ -3875,9 +4030,7 @@ def test_basic_model_with_accelerate_homogeneous_mapping_works(tmp_path):
mlflow.transformers.save_model(transformers_model=pipeline, path=str(tmp_path / "model"))

loaded = mlflow.transformers.load_model(str(tmp_path / "model"))

text = "Apples are delicious"

assert loaded(text) == pipeline(text)

