Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions ads/model/generic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2497,6 +2497,7 @@ def predict(
self,
data: Any = None,
auto_serialize_data: bool = False,
local: bool = False,
**kwargs,
) -> Dict[str, Any]:
"""Returns prediction of input data run against the model deployment endpoint.
Expand All @@ -2521,6 +2522,8 @@ def predict(
Whether to auto serialize input data. Defauls to `False` for GenericModel, and `True` for other frameworks.
`data` required to be json serializable if `auto_serialize_data=False`.
If `auto_serialize_data` set to True, data will be serialized before sending to model deployment endpoint.
local: bool.
Whether to invoke the prediction locally. Default to False.
kwargs:
content_type: str, used to indicate the media type of the resource.
image: PIL.Image Object or uri for the image.
Expand All @@ -2539,10 +2542,21 @@ def predict(
NotActiveDeploymentError
If model deployment process was not started or not finished yet.
ValueError
If `data` is empty or not JSON serializable.
If model is not deployed yet or the endpoint information is not available.
"""
if not self.model_deployment:
raise ValueError("Use `deploy()` method to start model deployment.")
if local:
return self.verify(
data=data, auto_serialize_data=auto_serialize_data, **kwargs
)

if not (self.model_deployment and self.model_deployment.url):
raise ValueError(
"Error invoking the remote endpoint as the model is not "
"deployed yet or the endpoint information is not available. "
"Use `deploy()` method to start model deployment. "
"If you intend to invoke inference using locally available "
"model artifact, set parameter `local=True`"
)

current_state = self.model_deployment.state.name.upper()
if current_state != ModelDeploymentState.ACTIVE.name:
Expand Down
77 changes: 53 additions & 24 deletions tests/unitary/with_extras/model/test_generic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,20 @@

INFERENCE_CONDA_ENV = "oci://bucket@namespace/<path_to_service_pack>"
TRAINING_CONDA_ENV = "oci://bucket@namespace/<path_to_service_pack>"
DEFAULT_PYTHON_VERSION = "3.8"
MODEL_FILE_NAME = "fake_model_name"
FAKE_MD_URL = "http://<model-deployment-url>"


def _prepare(model):
model.prepare(
inference_conda_env=INFERENCE_CONDA_ENV,
inference_python_version=DEFAULT_PYTHON_VERSION,
training_conda_env=TRAINING_CONDA_ENV,
training_python_version=DEFAULT_PYTHON_VERSION,
model_file_name=MODEL_FILE_NAME,
force_overwrite=True,
)


class TestEstimator:
Expand Down Expand Up @@ -315,14 +329,7 @@ def test_prepare_with_custom_scorepy(self, mock_signer):
@patch("ads.common.auth.default_signer")
def test_verify_without_reload(self, mock_signer):
"""Test verify input data without reload artifacts."""
self.generic_model.prepare(
inference_conda_env="oci://service-conda-packs@ociodscdev/service_pack/cpu/General_Machine_Learning_for_CPUs/1.0/mlcpuv1",
inference_python_version="3.6",
training_conda_env="oci://service-conda-packs@ociodscdev/service_pack/cpu/Oracle_Database_for_CPU_Python_3.7/1.0/database_p37_cpu_v1",
training_python_version="3.7",
model_file_name="fake_model_name",
force_overwrite=True,
)
_prepare(self.generic_model)
self.generic_model.verify(self.X_test.tolist())

with patch("ads.model.artifact.ModelArtifact.reload") as mock_reload:
Expand All @@ -332,20 +339,10 @@ def test_verify_without_reload(self, mock_signer):
@patch("ads.common.auth.default_signer")
def test_verify(self, mock_signer):
"""Test verify input data"""
self.generic_model.prepare(
inference_conda_env="oci://service-conda-packs@ociodscdev/service_pack/cpu/General_Machine_Learning_for_CPUs/1.0/mlcpuv1",
inference_python_version="3.6",
training_conda_env="oci://service-conda-packs@ociodscdev/service_pack/cpu/Oracle_Database_for_CPU_Python_3.7/1.0/database_p37_cpu_v1",
training_python_version="3.7",
model_file_name="fake_model_name",
force_overwrite=True,
)
_prepare(self.generic_model)
prediction_1 = self.generic_model.verify(self.X_test.tolist())
assert isinstance(prediction_1, dict), "Failed to verify json payload."

prediction_2 = self.generic_model.verify(self.X_test.tolist())
assert isinstance(prediction_2, dict), "Failed to verify input data."

def test_reload(self):
"""test the reload."""
pass
Expand Down Expand Up @@ -637,11 +634,31 @@ def test_deploy_with_default_display_name(self, mock_deploy):
== random_name[:-9]
)

@pytest.mark.parametrize("input_data", [(X_test.tolist())])
@patch("ads.common.auth.default_signer")
def test_predict_locally(self, mock_signer, input_data):
_prepare(self.generic_model)
test_result = self.generic_model.predict(data=input_data, local=True)
expected_result = self.generic_model.estimator.predict(input_data).tolist()
assert (
test_result["prediction"] == expected_result
), "Failed to verify input data."

with patch("ads.model.artifact.ModelArtifact.reload") as mock_reload:
self.generic_model.predict(
data=input_data, local=True, reload_artifacts=False
)
mock_reload.assert_not_called()

@patch.object(ModelDeployment, "predict")
@patch("ads.common.auth.default_signer")
@patch("ads.common.oci_client.OCIClientFactory")
@patch(
"ads.model.deployment.model_deployment.ModelDeployment.url",
return_value=FAKE_MD_URL,
)
def test_predict_with_not_active_deployment_fail(
self, mock_client, mock_signer, mock_predict
self, mock_url, mock_client, mock_signer, mock_predict
):
"""Ensures predict model fails in case of model deployment is not in an active state."""
with pytest.raises(NotActiveDeploymentError):
Expand All @@ -661,7 +678,11 @@ def test_predict_with_not_active_deployment_fail(

@patch("ads.common.auth.default_signer")
@patch("ads.common.oci_client.OCIClientFactory")
def test_predict_bytes_success(self, mock_client, mock_signer):
@patch(
"ads.model.deployment.model_deployment.ModelDeployment.url",
return_value=FAKE_MD_URL,
)
def test_predict_bytes_success(self, mock_url, mock_client, mock_signer):
"""Ensures predict model passes with bytes input."""
with patch.object(
ModelDeployment, "state", new_callable=PropertyMock
Expand All @@ -670,7 +691,7 @@ def test_predict_bytes_success(self, mock_client, mock_signer):
with patch.object(ModelDeployment, "predict") as mock_predict:
mock_predict.return_value = {"result": "result"}
self.generic_model.model_deployment = ModelDeployment(
model_deployment_id="test"
model_deployment_id="test",
)
# self.generic_model.model_deployment.current_state = ModelDeploymentState.ACTIVE
self.generic_model._as_onnx = False
Expand All @@ -683,7 +704,11 @@ def test_predict_bytes_success(self, mock_client, mock_signer):

@patch("ads.common.auth.default_signer")
@patch("ads.common.oci_client.OCIClientFactory")
def test_predict_success(self, mock_client, mock_signer):
@patch(
"ads.model.deployment.model_deployment.ModelDeployment.url",
return_value=FAKE_MD_URL,
)
def test_predict_success(self, mock_url, mock_client, mock_signer):
"""Ensures predict model passes with valid input parameters."""
with patch.object(
ModelDeployment, "state", new_callable=PropertyMock
Expand Down Expand Up @@ -800,7 +825,11 @@ def test_from_model_artifact(

@patch("ads.common.auth.default_signer")
@patch("ads.common.oci_client.OCIClientFactory")
def test_predict_success__serialize_input(self, mock_client, mock_signer):
@patch(
"ads.model.deployment.model_deployment.ModelDeployment.url",
return_value=FAKE_MD_URL,
)
def test_predict_success__serialize_input(self, mock_url, mock_client, mock_signer):
"""Ensures predict model passes with valid input parameters."""

df = pd.DataFrame([1, 2, 3])
Expand Down