168168 "training_script" : None ,
169169}
170170
171- INFERENCE_CONDA_ENV = "oci://service-conda-packs@ociodscdev/service_pack/cpu/General_Machine_Learning_for_CPUs/1.0/mlcpuv1"
172- TRAINING_CONDA_ENV = "oci://service-conda-packs@ociodscdev/service_pack/cpu/Oracle_Database_for_CPU_Python_3.7/1.0/database_p37_cpu_v1"
171+ INFERENCE_CONDA_ENV = "oci://service-conda-packs@ociodscdev/service_pack/cpu/General_Machine_Learning_for_CPUs/1.0/mlcpuv1"
172+ TRAINING_CONDA_ENV = "oci://service-conda-packs@ociodscdev/service_pack/cpu/Oracle_Database_for_CPU_Python_3.7/1.0/database_p37_cpu_v1"
173173DEFAULT_PYTHON_VERSION = "3.8"
174174MODEL_FILE_NAME = "fake_model_name"
175+ FAKE_MD_URL = "http://<model-deployment-url>"
176+
175177
176178def _prepare (model ):
177179 model .prepare (
@@ -182,13 +184,14 @@ def _prepare(model):
182184 model_file_name = MODEL_FILE_NAME ,
183185 force_overwrite = True ,
184186 )
187+
188+
185189class TestEstimator :
186190 def predict (self , x ):
187191 return x ** 2
188192
189193
190194class TestGenericModel :
191-
192195 iris = load_iris ()
193196 X , y = iris .data , iris .target
194197 X_train , X_test , y_train , y_test = train_test_split (X , y )
@@ -618,26 +621,31 @@ def test_deploy_with_default_display_name(self, mock_deploy):
618621 == random_name [:- 9 ]
619622 )
620623
621- @pytest .mark .parametrize (
622- "input_data" ,
623- [(X_test .tolist ())]
624- )
624+ @pytest .mark .parametrize ("input_data" , [(X_test .tolist ())])
625625 @patch ("ads.common.auth.default_signer" )
626626 def test_predict_locally (self , mock_signer , input_data ):
627627 _prepare (self .generic_model )
628628 test_result = self .generic_model .predict (data = input_data , local = True )
629629 expected_result = self .generic_model .estimator .predict (input_data ).tolist ()
630- assert test_result ['prediction' ] == expected_result , "Failed to verify input data."
630+ assert (
631+ test_result ["prediction" ] == expected_result
632+ ), "Failed to verify input data."
631633
632634 with patch ("ads.model.artifact.ModelArtifact.reload" ) as mock_reload :
633- self .generic_model .predict (data = input_data , local = True , reload_artifacts = False )
635+ self .generic_model .predict (
636+ data = input_data , local = True , reload_artifacts = False
637+ )
634638 mock_reload .assert_not_called ()
635639
636640 @patch .object (ModelDeployment , "predict" )
637641 @patch ("ads.common.auth.default_signer" )
638642 @patch ("ads.common.oci_client.OCIClientFactory" )
643+ @patch (
644+ "ads.model.deployment.model_deployment.ModelDeployment.url" ,
645+ return_value = FAKE_MD_URL ,
646+ )
639647 def test_predict_with_not_active_deployment_fail (
640- self , mock_client , mock_signer , mock_predict
648+ self , mock_url , mock_client , mock_signer , mock_predict
641649 ):
642650 """Ensures predict model fails in case of model deployment is not in an active state."""
643651 with pytest .raises (NotActiveDeploymentError ):
@@ -657,7 +665,11 @@ def test_predict_with_not_active_deployment_fail(
657665
658666 @patch ("ads.common.auth.default_signer" )
659667 @patch ("ads.common.oci_client.OCIClientFactory" )
660- def test_predict_bytes_success (self , mock_client , mock_signer ):
668+ @patch (
669+ "ads.model.deployment.model_deployment.ModelDeployment.url" ,
670+ return_value = FAKE_MD_URL ,
671+ )
672+ def test_predict_bytes_success (self , mock_url , mock_client , mock_signer ):
661673 """Ensures predict model passes with bytes input."""
662674 with patch .object (
663675 ModelDeployment , "state" , new_callable = PropertyMock
@@ -666,7 +678,7 @@ def test_predict_bytes_success(self, mock_client, mock_signer):
666678 with patch .object (ModelDeployment , "predict" ) as mock_predict :
667679 mock_predict .return_value = {"result" : "result" }
668680 self .generic_model .model_deployment = ModelDeployment (
669- model_deployment_id = "test"
681+ model_deployment_id = "test" ,
670682 )
671683 # self.generic_model.model_deployment.current_state = ModelDeploymentState.ACTIVE
672684 self .generic_model ._as_onnx = False
@@ -679,7 +691,11 @@ def test_predict_bytes_success(self, mock_client, mock_signer):
679691
680692 @patch ("ads.common.auth.default_signer" )
681693 @patch ("ads.common.oci_client.OCIClientFactory" )
682- def test_predict_success (self , mock_client , mock_signer ):
694+ @patch (
695+ "ads.model.deployment.model_deployment.ModelDeployment.url" ,
696+ return_value = FAKE_MD_URL ,
697+ )
698+ def test_predict_success (self , mock_url , mock_client , mock_signer ):
683699 """Ensures predict model passes with valid input parameters."""
684700 with patch .object (
685701 ModelDeployment , "state" , new_callable = PropertyMock
@@ -796,7 +812,11 @@ def test_from_model_artifact(
796812
797813 @patch ("ads.common.auth.default_signer" )
798814 @patch ("ads.common.oci_client.OCIClientFactory" )
799- def test_predict_success__serialize_input (self , mock_client , mock_signer ):
815+ @patch (
816+ "ads.model.deployment.model_deployment.ModelDeployment.url" ,
817+ return_value = FAKE_MD_URL ,
818+ )
819+ def test_predict_success__serialize_input (self , mock_url , mock_client , mock_signer ):
800820 """Ensures predict model passes with valid input parameters."""
801821
802822 df = pd .DataFrame ([1 , 2 , 3 ])
@@ -806,7 +826,6 @@ def test_predict_success__serialize_input(self, mock_client, mock_signer):
806826 with patch .object (
807827 GenericModel , "get_data_serializer"
808828 ) as mock_get_data_serializer :
809-
810829 mock_get_data_serializer .return_value .data = df .to_json ()
811830 mock_state .return_value = ModelDeploymentState .ACTIVE
812831 with patch .object (ModelDeployment , "predict" ) as mock_predict :
@@ -1793,7 +1812,6 @@ def test_upload_artifact_fail(self):
17931812 def test_upload_artifact_success (self ):
17941813 """Tests uploading model artifacts to the provided `uri`."""
17951814 with tempfile .TemporaryDirectory () as tmp_dir :
1796-
17971815 # copy test artifacts to the temp folder
17981816 shutil .copytree (
17991817 os .path .join (self .curr_dir , "test_files/valid_model_artifacts" ),
0 commit comments