diff --git a/ads/model/generic_model.py b/ads/model/generic_model.py index 226cb433f..0db88d51c 100644 --- a/ads/model/generic_model.py +++ b/ads/model/generic_model.py @@ -2497,6 +2497,7 @@ def predict( self, data: Any = None, auto_serialize_data: bool = False, + local: bool = False, **kwargs, ) -> Dict[str, Any]: """Returns prediction of input data run against the model deployment endpoint. @@ -2521,6 +2522,8 @@ def predict( Whether to auto serialize input data. Defauls to `False` for GenericModel, and `True` for other frameworks. `data` required to be json serializable if `auto_serialize_data=False`. If `auto_serialize_data` set to True, data will be serialized before sending to model deployment endpoint. + local: bool. + Whether to invoke the prediction locally. Default to False. kwargs: content_type: str, used to indicate the media type of the resource. image: PIL.Image Object or uri for the image. @@ -2539,10 +2542,21 @@ def predict( NotActiveDeploymentError If model deployment process was not started or not finished yet. ValueError - If `data` is empty or not JSON serializable. + If model is not deployed yet or the endpoint information is not available. """ - if not self.model_deployment: - raise ValueError("Use `deploy()` method to start model deployment.") + if local: + return self.verify( + data=data, auto_serialize_data=auto_serialize_data, **kwargs + ) + + if not (self.model_deployment and self.model_deployment.url): + raise ValueError( + "Error invoking the remote endpoint as the model is not " + "deployed yet or the endpoint information is not available. " + "Use `deploy()` method to start model deployment. " + "If you intend to invoke inference using locally available " + "model artifact, set parameter `local=True`" + ) current_state = self.model_deployment.state.name.upper() if current_state != ModelDeploymentState.ACTIVE.name: diff --git a/tests/unitary/with_extras/model/test_generic_model.py b/tests/unitary/with_extras/model/test_generic_model.py index f3f46e9aa..602b4e001 100644 --- a/tests/unitary/with_extras/model/test_generic_model.py +++ b/tests/unitary/with_extras/model/test_generic_model.py @@ -170,6 +170,20 @@ INFERENCE_CONDA_ENV = "oci://bucket@namespace/" TRAINING_CONDA_ENV = "oci://bucket@namespace/" +DEFAULT_PYTHON_VERSION = "3.8" +MODEL_FILE_NAME = "fake_model_name" +FAKE_MD_URL = "http://" + + +def _prepare(model): + model.prepare( + inference_conda_env=INFERENCE_CONDA_ENV, + inference_python_version=DEFAULT_PYTHON_VERSION, + training_conda_env=TRAINING_CONDA_ENV, + training_python_version=DEFAULT_PYTHON_VERSION, + model_file_name=MODEL_FILE_NAME, + force_overwrite=True, + ) class TestEstimator: @@ -315,14 +329,7 @@ def test_prepare_with_custom_scorepy(self, mock_signer): @patch("ads.common.auth.default_signer") def test_verify_without_reload(self, mock_signer): """Test verify input data without reload artifacts.""" - self.generic_model.prepare( - inference_conda_env="oci://service-conda-packs@ociodscdev/service_pack/cpu/General_Machine_Learning_for_CPUs/1.0/mlcpuv1", - inference_python_version="3.6", - training_conda_env="oci://service-conda-packs@ociodscdev/service_pack/cpu/Oracle_Database_for_CPU_Python_3.7/1.0/database_p37_cpu_v1", - training_python_version="3.7", - model_file_name="fake_model_name", - force_overwrite=True, - ) + _prepare(self.generic_model) self.generic_model.verify(self.X_test.tolist()) with patch("ads.model.artifact.ModelArtifact.reload") as mock_reload: @@ -332,20 +339,10 @@ def test_verify_without_reload(self, mock_signer): @patch("ads.common.auth.default_signer") def test_verify(self, mock_signer): """Test verify input data""" - self.generic_model.prepare( - inference_conda_env="oci://service-conda-packs@ociodscdev/service_pack/cpu/General_Machine_Learning_for_CPUs/1.0/mlcpuv1", - inference_python_version="3.6", - training_conda_env="oci://service-conda-packs@ociodscdev/service_pack/cpu/Oracle_Database_for_CPU_Python_3.7/1.0/database_p37_cpu_v1", - training_python_version="3.7", - model_file_name="fake_model_name", - force_overwrite=True, - ) + _prepare(self.generic_model) prediction_1 = self.generic_model.verify(self.X_test.tolist()) assert isinstance(prediction_1, dict), "Failed to verify json payload." - prediction_2 = self.generic_model.verify(self.X_test.tolist()) - assert isinstance(prediction_2, dict), "Failed to verify input data." - def test_reload(self): """test the reload.""" pass @@ -637,11 +634,31 @@ def test_deploy_with_default_display_name(self, mock_deploy): == random_name[:-9] ) + @pytest.mark.parametrize("input_data", [(X_test.tolist())]) + @patch("ads.common.auth.default_signer") + def test_predict_locally(self, mock_signer, input_data): + _prepare(self.generic_model) + test_result = self.generic_model.predict(data=input_data, local=True) + expected_result = self.generic_model.estimator.predict(input_data).tolist() + assert ( + test_result["prediction"] == expected_result + ), "Failed to verify input data." + + with patch("ads.model.artifact.ModelArtifact.reload") as mock_reload: + self.generic_model.predict( + data=input_data, local=True, reload_artifacts=False + ) + mock_reload.assert_not_called() + @patch.object(ModelDeployment, "predict") @patch("ads.common.auth.default_signer") @patch("ads.common.oci_client.OCIClientFactory") + @patch( + "ads.model.deployment.model_deployment.ModelDeployment.url", + return_value=FAKE_MD_URL, + ) def test_predict_with_not_active_deployment_fail( - self, mock_client, mock_signer, mock_predict + self, mock_url, mock_client, mock_signer, mock_predict ): """Ensures predict model fails in case of model deployment is not in an active state.""" with pytest.raises(NotActiveDeploymentError): @@ -661,7 +678,11 @@ def test_predict_with_not_active_deployment_fail( @patch("ads.common.auth.default_signer") @patch("ads.common.oci_client.OCIClientFactory") - def test_predict_bytes_success(self, mock_client, mock_signer): + @patch( + "ads.model.deployment.model_deployment.ModelDeployment.url", + return_value=FAKE_MD_URL, + ) + def test_predict_bytes_success(self, mock_url, mock_client, mock_signer): """Ensures predict model passes with bytes input.""" with patch.object( ModelDeployment, "state", new_callable=PropertyMock @@ -670,7 +691,7 @@ def test_predict_bytes_success(self, mock_client, mock_signer): with patch.object(ModelDeployment, "predict") as mock_predict: mock_predict.return_value = {"result": "result"} self.generic_model.model_deployment = ModelDeployment( - model_deployment_id="test" + model_deployment_id="test", ) # self.generic_model.model_deployment.current_state = ModelDeploymentState.ACTIVE self.generic_model._as_onnx = False @@ -683,7 +704,11 @@ def test_predict_bytes_success(self, mock_client, mock_signer): @patch("ads.common.auth.default_signer") @patch("ads.common.oci_client.OCIClientFactory") - def test_predict_success(self, mock_client, mock_signer): + @patch( + "ads.model.deployment.model_deployment.ModelDeployment.url", + return_value=FAKE_MD_URL, + ) + def test_predict_success(self, mock_url, mock_client, mock_signer): """Ensures predict model passes with valid input parameters.""" with patch.object( ModelDeployment, "state", new_callable=PropertyMock @@ -800,7 +825,11 @@ def test_from_model_artifact( @patch("ads.common.auth.default_signer") @patch("ads.common.oci_client.OCIClientFactory") - def test_predict_success__serialize_input(self, mock_client, mock_signer): + @patch( + "ads.model.deployment.model_deployment.ModelDeployment.url", + return_value=FAKE_MD_URL, + ) + def test_predict_success__serialize_input(self, mock_url, mock_client, mock_signer): """Ensures predict model passes with valid input parameters.""" df = pd.DataFrame([1, 2, 3])