Skip to content

Commit

Permalink
feat: Implement Model.copy functionality.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 516678124
  • Loading branch information
vertex-sdk-bot authored and Copybara-Service committed Mar 15, 2023
1 parent 8cb4377 commit 94dd82f
Show file tree
Hide file tree
Showing 2 changed files with 198 additions and 0 deletions.
119 changes: 119 additions & 0 deletions google/cloud/aiplatform/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4656,6 +4656,125 @@ def upload_tensorflow_saved_model(
upload_request_timeout=upload_request_timeout,
)

# TODO(b/273499620): Add async support.
def copy(
self,
destination_location: str,
destination_model_id: Optional[str] = None,
destination_parent_model: Optional[str] = None,
encryption_spec_key_name: Optional[str] = None,
copy_request_timeout: Optional[float] = None,
) -> "Model":
"""Copys a model and returns a Model representing the copied Model
resource. This method is a blocking call.
Example usage:
copied_model = my_model.copy(
destination_location="us-central1"
)
Args:
destination_location (str):
The destination location to copy the model to.
destination_model_id (str):
Optional. The ID to use for the copied Model, which will
become the final component of the model resource name.
This value may be up to 63 characters, and valid characters
are `[a-z0-9_-]`. The first character cannot be a number or hyphen.
Only set this field when copying as a new model. If this field is not set,
a numeric model id will be generated.
destination_parent_model (str):
Optional. The resource name or model ID of an existing model that the
newly-copied model will be a version of.
Only set this field when copying as a new version of an existing model.
encryption_spec_key_name (Optional[str]):
Optional. The Cloud KMS resource identifier of the customer
managed encryption key used to protect the model. Has the
form:
``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
The key needs to be in the same region as where the compute
resource is created.
If set, this Model and all sub-resources of this Model will be secured by this key.
Overrides encryption_spec_key_name set in aiplatform.init.
copy_request_timeout (float):
Optional. The timeout for the copy request in seconds.
Returns:
model (aiplatform.Model):
Instantiated representation of the copied model resource.
Raises:
ValueError: If both `destination_model_id` and `destination_parent_model` are set.
"""
if destination_model_id is not None and destination_parent_model is not None:
raise ValueError(
"`destination_model_id` and `destination_parent_model` can not be set together."
)

parent = initializer.global_config.common_location_path(
initializer.global_config.project, destination_location
)

source_model = self.versioned_resource_name

destination_parent_model = ModelRegistry._get_true_version_parent(
parent_model=destination_parent_model,
project=initializer.global_config.project,
location=destination_location,
)

encryption_spec = initializer.global_config.get_encryption_spec(
encryption_spec_key_name=encryption_spec_key_name,
)

if destination_model_id is not None:
request = gca_model_service_compat.CopyModelRequest(
parent=parent,
source_model=source_model,
model_id=destination_model_id,
encryption_spec=encryption_spec,
)
else:
request = gca_model_service_compat.CopyModelRequest(
parent=parent,
source_model=source_model,
parent_model=destination_parent_model,
encryption_spec=encryption_spec,
)

api_client = initializer.global_config.create_client(
client_class=utils.ModelClientWithOverride,
location_override=destination_location,
credentials=initializer.global_config.credentials,
)

_LOGGER.log_action_start_against_resource("Copying", "", self)

lro = api_client.copy_model(
request=request,
timeout=copy_request_timeout,
)

_LOGGER.log_action_started_against_resource_with_lro(
"Copy", "", self.__class__, lro
)

model_copy_response = lro.result(timeout=None)

this_model = models.Model(
model_copy_response.model,
version=model_copy_response.model_version_id,
location=destination_location,
)

_LOGGER.log_action_completed_against_resource("", "copied", this_model)

return this_model

def list_model_evaluations(
self,
) -> List["model_evaluation.ModelEvaluation"]:
Expand Down
79 changes: 79 additions & 0 deletions tests/unit/aiplatform/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
_TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}"
_TEST_MODEL_NAME = "123"
_TEST_MODEL_NAME_ALT = "456"
_TEST_MODEL_ID = "my-model"
_TEST_MODEL_PARENT = (
f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/models/{_TEST_MODEL_NAME}"
)
Expand Down Expand Up @@ -581,6 +582,19 @@ def delete_model_mock():
yield delete_model_mock


@pytest.fixture
def copy_model_mock():
with mock.patch.object(
model_service_client.ModelServiceClient, "copy_model"
) as copy_model_mock:
mock_lro = mock.Mock(ga_operation.Operation)
mock_lro.result.return_value = gca_model_service.CopyModelResponse(
model=_TEST_MODEL_RESOURCE_NAME_CUSTOM_LOCATION
)
copy_model_mock.return_value = mock_lro
yield copy_model_mock


@pytest.fixture
def deploy_model_mock():
with mock.patch.object(
Expand Down Expand Up @@ -2419,6 +2433,71 @@ def test_upload_tensorflow_saved_model_uploads_and_gets_model(
staged_model_file_name = staged_model_file_path.split("/")[-1]
assert staged_model_file_name in ["saved_model.pb", "saved_model.pbtxt"]

def test_copy_as_new_model(self, copy_model_mock, get_model_mock):

test_model = models.Model(_TEST_ID)
test_model.copy(destination_location=_TEST_LOCATION_2)

copy_model_mock.assert_called_once_with(
request=gca_model_service.CopyModelRequest(
parent=initializer.global_config.common_location_path(
location=_TEST_LOCATION_2
),
source_model=_TEST_MODEL_RESOURCE_NAME,
),
timeout=None,
)

def test_copy_as_new_version(self, copy_model_mock, get_model_mock):
test_model = models.Model(_TEST_ID)
test_model.copy(
destination_location=_TEST_LOCATION_2,
destination_parent_model=_TEST_MODEL_NAME_ALT,
)

copy_model_mock.assert_called_once_with(
request=gca_model_service.CopyModelRequest(
parent=initializer.global_config.common_location_path(
location=_TEST_LOCATION_2
),
source_model=_TEST_MODEL_RESOURCE_NAME,
parent_model=model_service_client.ModelServiceClient.model_path(
_TEST_PROJECT, _TEST_LOCATION_2, _TEST_MODEL_NAME_ALT
),
),
timeout=None,
)

def test_copy_as_new_model_custom_id(self, copy_model_mock, get_model_mock):
test_model = models.Model(_TEST_ID)
test_model.copy(
destination_location=_TEST_LOCATION_2, destination_model_id=_TEST_MODEL_ID
)

copy_model_mock.assert_called_once_with(
request=gca_model_service.CopyModelRequest(
parent=initializer.global_config.common_location_path(
location=_TEST_LOCATION_2
),
source_model=_TEST_MODEL_RESOURCE_NAME,
model_id=_TEST_MODEL_ID,
),
timeout=None,
)

def test_copy_with_invalid_params(self, copy_model_mock, get_model_mock):
with pytest.raises(ValueError) as e:
test_model = models.Model(_TEST_ID)
test_model.copy(
destination_location=_TEST_LOCATION,
destination_model_id=_TEST_MODEL_ID,
destination_parent_model=_TEST_MODEL_RESOURCE_NAME,
)

assert e.match(
regexp=r"`destination_model_id` and `destination_parent_model` can not be set together."
)

@pytest.mark.usefixtures("get_model_mock")
def test_update(self, update_model_mock, get_model_mock):

Expand Down

0 comments on commit 94dd82f

Please sign in to comment.