Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 9 additions & 12 deletions ads/model/deployment/model_deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ class LogNotConfiguredError(Exception): # pragma: no cover
pass


# NOTE(review): this is a rendered diff with its +/- markers stripped. The two
# class headers below appear to be the old and the new name of one exception
# class (a rename) sharing the single `pass` body -- confirm against the real file.
# In the visible hunks, ModelDeploymentFailedError is raised by deploy() when the
# response's lifecycle_state equals State.FAILED.name, and
# ModelDeploymentPredictError is raised by predict() when the synced
# lifecycle_state is not State.ACTIVE.name.
class ModelDeploymentFailedError(Exception): # pragma: no cover
class ModelDeploymentPredictError(Exception): # pragma: no cover
pass


Expand Down Expand Up @@ -607,11 +607,6 @@ def deploy(
-------
ModelDeployment
The instance of ModelDeployment.

Raises
------
ModelDeploymentFailedError
If model deployment fails to deploy
"""
create_model_deployment_details = (
self._build_model_deployment_details()
Expand All @@ -626,11 +621,6 @@ def deploy(
poll_interval=poll_interval,
)

if response.lifecycle_state == State.FAILED.name:
raise ModelDeploymentFailedError(
f"Model deployment {response.id} failed to deploy: {response.lifecycle_details}"
)

return self._update_from_oci_model(response)

def delete(
Expand Down Expand Up @@ -662,6 +652,7 @@ def delete(
max_wait_time=max_wait_time,
poll_interval=poll_interval,
)

return self._update_from_oci_model(response)

def update(
Expand Down Expand Up @@ -890,6 +881,12 @@ def predict(
Prediction results.

"""
current_state = self.sync().lifecycle_state
if current_state != State.ACTIVE.name:
raise ModelDeploymentPredictError(
"This model deployment is not in active state, you will not be able to use predict end point. "
f"Current model deployment state: {current_state} "
)
endpoint = f"{self.url}/predict"
signer = authutil.default_signer()["signer"]
header = {
Expand Down Expand Up @@ -953,7 +950,7 @@ def predict(
except oci.exceptions.ServiceError as ex:
# When bandwidth exceeds the allocated value, TooManyRequests error (429) will be raised by oci backend.
if ex.status == 429:
bandwidth_mbps = self.infrastructure.bandwidth_mbps or MODEL_DEPLOYMENT_BANDWIDTH_MBPS
bandwidth_mbps = self.infrastructure.bandwidth_mbps or DEFAULT_BANDWIDTH_MBPS
utils.get_logger().warning(
f"Load balancer bandwidth exceeds the allocated {bandwidth_mbps} Mbps."
"To estimate the actual bandwidth, use formula: (payload size in KB) * (estimated requests per second) * 8 / 1024."
Expand Down
5 changes: 4 additions & 1 deletion ads/model/generic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1574,7 +1574,10 @@ def from_model_deployment(

current_state = model_deployment.state.name.upper()
if current_state != ModelDeploymentState.ACTIVE.name:
raise NotActiveDeploymentError(current_state)
logger.warning(
"This model deployment is not in active state, you will not be able to use predict end point. "
f"Current model deployment state: `{current_state}`"
)

model = cls.from_model_catalog(
model_id=model_deployment.properties.model_id,
Expand Down
17 changes: 7 additions & 10 deletions ads/model/service/oci_datascience_model_deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,9 +209,8 @@ def activate(
)
except Exception as e:
logger.error(
f"Error while trying to activate model deployment: {self.id}"
"Error while trying to activate model deployment: " + str(e)
)
raise e

return self.sync()
else:
Expand Down Expand Up @@ -261,9 +260,8 @@ def create(
)
except Exception as e:
logger.error(
f"Error while trying to create model deployment: {self.id}"
"Error while trying to create model deployment: " + str(e)
)
raise e

return self.sync()

Expand Down Expand Up @@ -325,9 +323,8 @@ def deactivate(
)
except Exception as e:
logger.error(
f"Error while trying to deactivate model deployment: {self.id}"
"Error while trying to deactivate model deployment: " + str(e)
)
raise e

return self.sync()
else:
Expand Down Expand Up @@ -396,9 +393,8 @@ def delete(
)
except Exception as e:
logger.error(
f"Error while trying to delete model deployment: {self.id}"
"Error while trying to delete model deployment: " + str(e)
)
raise e

return self.sync()

Expand Down Expand Up @@ -452,8 +448,9 @@ def update(
)
self.workflow_req_id = response.headers.get("opc-work-request-id", None)
except Exception as e:
logger.error(f"Error while trying to update model deployment: {self.id}")
raise e
logger.error(
"Error while trying to update model deployment: " + str(e)
)

return self.sync()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@ class ModelDeploymentTestCase(unittest.TestCase):
)

@patch("requests.post")
def test_predict(self, mock_post):
@patch("ads.model.deployment.model_deployment.ModelDeployment.sync")
def test_predict(self, mock_sync, mock_post):
"""Ensures predict model passes with valid input parameters."""
mock_sync.return_value = Mock(lifecycle_state="ACTIVE")
mock_post.return_value = Mock(
status_code=200, json=lambda: {"result": "result"}
)
Expand All @@ -50,8 +52,10 @@ def test_predict(self, mock_post):
self.test_model_deployment.predict(data=np.array([1, 2, 3]))

@patch("requests.post")
def test_predict_with_bytes(self, mock_post):
@patch("ads.model.deployment.model_deployment.ModelDeployment.sync")
def test_predict_with_bytes(self, mock_sync, mock_post):
"""Ensures predict model passes with bytes input."""
mock_sync.return_value = Mock(lifecycle_state="ACTIVE")
byte_data = b"[[1,2,3,4]]"
with patch.object(authutil, "default_signer") as mock_auth:
auth = MagicMock()
Expand All @@ -66,8 +70,10 @@ def test_predict_with_bytes(self, mock_post):
)

@patch("requests.post")
def test_predict_with_auto_serialize_data(self, mock_post):
@patch("ads.model.deployment.model_deployment.ModelDeployment.sync")
def test_predict_with_auto_serialize_data(self, mock_sync, mock_post):
"""Ensures predict model passes with valid input parameters."""
mock_sync.return_value = Mock(lifecycle_state="ACTIVE")
mock_post.return_value = Mock(
status_code=200, json=lambda: {"result": "result"}
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from ads.model.deployment.model_deployment import (
ModelDeployment,
ModelDeploymentLogType,
ModelDeploymentFailedError,
)
from ads.model.deployment.model_deployment_infrastructure import (
ModelDeploymentInfrastructure,
Expand Down Expand Up @@ -1148,44 +1147,6 @@ def test_deploy(
mock_create_model_deployment.assert_called_with(create_model_deployment_details)
mock_sync.assert_called()

# Unit test: expects deploy() to raise ModelDeploymentFailedError when the
# mocked sync() hands back a model deployment whose lifecycle_state is "FAILED".
# patch decorators are applied bottom-up, so the mock parameters map as:
#   mock_create                  -> DataScienceModel.create
#   mock_create_model_deployment -> DataScienceClient.create_model_deployment
#   mock_sync                    -> OCIDataScienceMixin.sync
@patch.object(OCIDataScienceMixin, "sync")
@patch.object(
oci.data_science.DataScienceClient,
"create_model_deployment",
)
@patch.object(DataScienceModel, "create")
def test_deploy_failed(
self, mock_create, mock_create_model_deployment, mock_sync
):
# Stand-in for the model returned by DataScienceModel.create, with a fake OCID.
dsc_model = MagicMock()
dsc_model.id = "fakeid.datasciencemodel.oc1.iad.xxx"
mock_create.return_value = dsc_model
# OCI response whose payload is a deployment in FAILED state; the
# lifecycle_details text is what the expected error message must include.
response = oci.response.Response(
status=MagicMock(),
headers=MagicMock(),
request=MagicMock(),
data=oci.data_science.models.ModelDeployment(
id="test_model_deployment_id",
lifecycle_state="FAILED",
lifecycle_details="The specified log object is not found or user is not authorized.",
),
)
# sync() is mocked to return that FAILED deployment.
mock_sync.return_value = response.data
model_deployment = self.initialize_model_deployment()
# Captured before deploy() so it can be compared with the argument passed to
# the mocked create_model_deployment call.
create_model_deployment_details = (
model_deployment._build_model_deployment_details()
)
# NOTE(review): pytest treats `match` as a regular expression (re.search); the
# message is not escaped, so "." here matches any character -- consider re.escape.
with pytest.raises(
ModelDeploymentFailedError,
match=f"Model deployment {response.data.id} failed to deploy: {response.data.lifecycle_details}",
):
model_deployment.deploy(wait_for_completion=False)
# NOTE(review): indentation is lost in this view. If the assertions below sit
# inside the `with` block they never execute once deploy() raises -- confirm
# against the real file and dedent them if so.
mock_create.assert_called()
mock_create_model_deployment.assert_called_with(
create_model_deployment_details
)
mock_sync.assert_called()

@patch.object(
OCIDataScienceModelDeployment,
"activate",
Expand Down
39 changes: 0 additions & 39 deletions tests/unitary/with_extras/model/test_generic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1051,45 +1051,6 @@ def test_from_model_deployment(

assert test_result == test_model

# Unit test: GenericModel.from_model_deployment is expected to raise
# NotActiveDeploymentError when the deployment's `state` is FAILED.
# patch decorators are applied bottom-up, so the mock parameters map as:
#   mock_client                 -> ads.common.oci_client.OCIClientFactory
#   mock_default_signer         -> ads.common.auth.default_signer
#   mock_from_id                -> ModelDeployment.from_id
#   mock_model_deployment_state -> ModelDeployment.state (PropertyMock -> FAILED)
@patch.object(
ModelDeployment,
"state",
new_callable=PropertyMock,
return_value=ModelDeploymentState.FAILED,
)
@patch.object(ModelDeployment, "from_id")
@patch("ads.common.auth.default_signer")
@patch("ads.common.oci_client.OCIClientFactory")
def test_from_model_deployment_fail(
self,
mock_client,
mock_default_signer,
mock_from_id,
mock_model_deployment_state,
):
"""Tests that loading a model from a FAILED model deployment raises NotActiveDeploymentError."""
test_auth_config = {"signer": {"config": "value"}}
mock_default_signer.return_value = test_auth_config
test_model_deployment_id = "md_ocid"
test_model_id = "model_ocid"
# from_id is mocked to return a ModelDeployment built only from properties
# carrying the model id; its `state` property is patched to FAILED above.
md_props = ModelDeploymentProperties(model_id=test_model_id)
md = ModelDeployment(properties=md_props)
mock_from_id.return_value = md

with pytest.raises(NotActiveDeploymentError):
GenericModel.from_model_deployment(
model_deployment_id=test_model_deployment_id,
model_file_name="test.pkl",
artifact_dir="test_dir",
auth=test_auth_config,
force_overwrite=True,
properties=None,
bucket_uri="test_bucket_uri",
remove_existing_artifact=True,
compartment_id="test_compartment_id",
)
# NOTE(review): indentation is lost in this view. If this assertion sits inside
# the `with` block it never executes once from_model_deployment() raises --
# confirm against the real file.
mock_from_id.assert_called_with(test_model_deployment_id)

@patch.object(ModelDeployment, "update")
@patch.object(ModelDeployment, "from_id")
@patch("ads.common.auth.default_signer")
Expand Down