Skip to content

Commit d61245a

Browse files
committed
fixed by comments: improved error message
1 parent 068c49e commit d61245a

File tree

2 files changed

+46
-20
lines changed

2 files changed

+46
-20
lines changed

ads/model/generic_model.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2528,13 +2528,21 @@ def predict(
25282528
NotActiveDeploymentError
25292529
If model deployment process was not started or not finished yet.
25302530
ValueError
2531-
If `data` is empty or not JSON serializable.
2531+
If model is not deployed yet or the endpoint information is not available.
25322532
"""
25332533
if local:
2534-
return self.verify(data=data, auto_serialize_data=auto_serialize_data, **kwargs)
2534+
return self.verify(
2535+
data=data, auto_serialize_data=auto_serialize_data, **kwargs
2536+
)
25352537

2536-
if not self.model_deployment:
2537-
raise ValueError("Use `deploy()` method to start model deployment.")
2538+
if not (self.model_deployment and self.model_deployment.url):
2539+
raise ValueError(
2540+
"Error invoking the remote endpoint as the model is not "
2541+
"deployed yet or the endpoint information is not available. "
2542+
"Use `deploy()` method to start model deployment. "
2543+
"If you intend to invoke inference using locally available "
2544+
"model artifact, set parameter `local=True`"
2545+
)
25382546

25392547
current_state = self.model_deployment.state.name.upper()
25402548
if current_state != ModelDeploymentState.ACTIVE.name:

tests/unitary/with_extras/model/test_generic_model.py

Lines changed: 34 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -168,10 +168,12 @@
168168
"training_script": None,
169169
}
170170

171-
INFERENCE_CONDA_ENV= "oci://service-conda-packs@ociodscdev/service_pack/cpu/General_Machine_Learning_for_CPUs/1.0/mlcpuv1"
172-
TRAINING_CONDA_ENV="oci://service-conda-packs@ociodscdev/service_pack/cpu/Oracle_Database_for_CPU_Python_3.7/1.0/database_p37_cpu_v1"
171+
INFERENCE_CONDA_ENV = "oci://service-conda-packs@ociodscdev/service_pack/cpu/General_Machine_Learning_for_CPUs/1.0/mlcpuv1"
172+
TRAINING_CONDA_ENV = "oci://service-conda-packs@ociodscdev/service_pack/cpu/Oracle_Database_for_CPU_Python_3.7/1.0/database_p37_cpu_v1"
173173
DEFAULT_PYTHON_VERSION = "3.8"
174174
MODEL_FILE_NAME = "fake_model_name"
175+
FAKE_MD_URL = "http://<model-deployment-url>"
176+
175177

176178
def _prepare(model):
177179
model.prepare(
@@ -182,13 +184,14 @@ def _prepare(model):
182184
model_file_name=MODEL_FILE_NAME,
183185
force_overwrite=True,
184186
)
187+
188+
185189
class TestEstimator:
186190
def predict(self, x):
187191
return x**2
188192

189193

190194
class TestGenericModel:
191-
192195
iris = load_iris()
193196
X, y = iris.data, iris.target
194197
X_train, X_test, y_train, y_test = train_test_split(X, y)
@@ -618,26 +621,31 @@ def test_deploy_with_default_display_name(self, mock_deploy):
618621
== random_name[:-9]
619622
)
620623

621-
@pytest.mark.parametrize(
622-
"input_data",
623-
[(X_test.tolist())]
624-
)
624+
@pytest.mark.parametrize("input_data", [(X_test.tolist())])
625625
@patch("ads.common.auth.default_signer")
626626
def test_predict_locally(self, mock_signer, input_data):
627627
_prepare(self.generic_model)
628628
test_result = self.generic_model.predict(data=input_data, local=True)
629629
expected_result = self.generic_model.estimator.predict(input_data).tolist()
630-
assert test_result['prediction'] == expected_result, "Failed to verify input data."
630+
assert (
631+
test_result["prediction"] == expected_result
632+
), "Failed to verify input data."
631633

632634
with patch("ads.model.artifact.ModelArtifact.reload") as mock_reload:
633-
self.generic_model.predict(data=input_data, local=True, reload_artifacts=False)
635+
self.generic_model.predict(
636+
data=input_data, local=True, reload_artifacts=False
637+
)
634638
mock_reload.assert_not_called()
635639

636640
@patch.object(ModelDeployment, "predict")
637641
@patch("ads.common.auth.default_signer")
638642
@patch("ads.common.oci_client.OCIClientFactory")
643+
@patch(
644+
"ads.model.deployment.model_deployment.ModelDeployment.url",
645+
return_value=FAKE_MD_URL,
646+
)
639647
def test_predict_with_not_active_deployment_fail(
640-
self, mock_client, mock_signer, mock_predict
648+
self, mock_url, mock_client, mock_signer, mock_predict
641649
):
642650
"""Ensures predict model fails in case of model deployment is not in an active state."""
643651
with pytest.raises(NotActiveDeploymentError):
@@ -657,7 +665,11 @@ def test_predict_with_not_active_deployment_fail(
657665

658666
@patch("ads.common.auth.default_signer")
659667
@patch("ads.common.oci_client.OCIClientFactory")
660-
def test_predict_bytes_success(self, mock_client, mock_signer):
668+
@patch(
669+
"ads.model.deployment.model_deployment.ModelDeployment.url",
670+
return_value=FAKE_MD_URL,
671+
)
672+
def test_predict_bytes_success(self, mock_url, mock_client, mock_signer):
661673
"""Ensures predict model passes with bytes input."""
662674
with patch.object(
663675
ModelDeployment, "state", new_callable=PropertyMock
@@ -666,7 +678,7 @@ def test_predict_bytes_success(self, mock_client, mock_signer):
666678
with patch.object(ModelDeployment, "predict") as mock_predict:
667679
mock_predict.return_value = {"result": "result"}
668680
self.generic_model.model_deployment = ModelDeployment(
669-
model_deployment_id="test"
681+
model_deployment_id="test",
670682
)
671683
# self.generic_model.model_deployment.current_state = ModelDeploymentState.ACTIVE
672684
self.generic_model._as_onnx = False
@@ -679,7 +691,11 @@ def test_predict_bytes_success(self, mock_client, mock_signer):
679691

680692
@patch("ads.common.auth.default_signer")
681693
@patch("ads.common.oci_client.OCIClientFactory")
682-
def test_predict_success(self, mock_client, mock_signer):
694+
@patch(
695+
"ads.model.deployment.model_deployment.ModelDeployment.url",
696+
return_value=FAKE_MD_URL,
697+
)
698+
def test_predict_success(self, mock_url, mock_client, mock_signer):
683699
"""Ensures predict model passes with valid input parameters."""
684700
with patch.object(
685701
ModelDeployment, "state", new_callable=PropertyMock
@@ -796,7 +812,11 @@ def test_from_model_artifact(
796812

797813
@patch("ads.common.auth.default_signer")
798814
@patch("ads.common.oci_client.OCIClientFactory")
799-
def test_predict_success__serialize_input(self, mock_client, mock_signer):
815+
@patch(
816+
"ads.model.deployment.model_deployment.ModelDeployment.url",
817+
return_value=FAKE_MD_URL,
818+
)
819+
def test_predict_success__serialize_input(self, mock_url, mock_client, mock_signer):
800820
"""Ensures predict model passes with valid input parameters."""
801821

802822
df = pd.DataFrame([1, 2, 3])
@@ -806,7 +826,6 @@ def test_predict_success__serialize_input(self, mock_client, mock_signer):
806826
with patch.object(
807827
GenericModel, "get_data_serializer"
808828
) as mock_get_data_serializer:
809-
810829
mock_get_data_serializer.return_value.data = df.to_json()
811830
mock_state.return_value = ModelDeploymentState.ACTIVE
812831
with patch.object(ModelDeployment, "predict") as mock_predict:
@@ -1793,7 +1812,6 @@ def test_upload_artifact_fail(self):
17931812
def test_upload_artifact_success(self):
17941813
"""Tests uploading model artifacts to the provided `uri`."""
17951814
with tempfile.TemporaryDirectory() as tmp_dir:
1796-
17971815
# copy test artifacts to the temp folder
17981816
shutil.copytree(
17991817
os.path.join(self.curr_dir, "test_files/valid_model_artifacts"),

0 commit comments

Comments
 (0)