diff --git a/google/cloud/aiplatform/models.py b/google/cloud/aiplatform/models.py index 60bb798166..608beecd4d 100644 --- a/google/cloud/aiplatform/models.py +++ b/google/cloud/aiplatform/models.py @@ -4656,6 +4656,125 @@ def upload_tensorflow_saved_model( upload_request_timeout=upload_request_timeout, ) + # TODO(b/273499620): Add async support. + def copy( + self, + destination_location: str, + destination_model_id: Optional[str] = None, + destination_parent_model: Optional[str] = None, + encryption_spec_key_name: Optional[str] = None, + copy_request_timeout: Optional[float] = None, + ) -> "Model": + """Copys a model and returns a Model representing the copied Model + resource. This method is a blocking call. + + Example usage: + copied_model = my_model.copy( + destination_location="us-central1" + ) + + Args: + destination_location (str): + The destination location to copy the model to. + destination_model_id (str): + Optional. The ID to use for the copied Model, which will + become the final component of the model resource name. + This value may be up to 63 characters, and valid characters + are `[a-z0-9_-]`. The first character cannot be a number or hyphen. + + Only set this field when copying as a new model. If this field is not set, + a numeric model id will be generated. + destination_parent_model (str): + Optional. The resource name or model ID of an existing model that the + newly-copied model will be a version of. + + Only set this field when copying as a new version of an existing model. + encryption_spec_key_name (Optional[str]): + Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the model. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, this Model and all sub-resources of this Model will be secured by this key. + + Overrides encryption_spec_key_name set in aiplatform.init. + copy_request_timeout (float): + Optional. The timeout for the copy request in seconds. + + Returns: + model (aiplatform.Model): + Instantiated representation of the copied model resource. + + Raises: + ValueError: If both `destination_model_id` and `destination_parent_model` are set. + """ + if destination_model_id is not None and destination_parent_model is not None: + raise ValueError( + "`destination_model_id` and `destination_parent_model` can not be set together." + ) + + parent = initializer.global_config.common_location_path( + initializer.global_config.project, destination_location + ) + + source_model = self.versioned_resource_name + + destination_parent_model = ModelRegistry._get_true_version_parent( + parent_model=destination_parent_model, + project=initializer.global_config.project, + location=destination_location, + ) + + encryption_spec = initializer.global_config.get_encryption_spec( + encryption_spec_key_name=encryption_spec_key_name, + ) + + if destination_model_id is not None: + request = gca_model_service_compat.CopyModelRequest( + parent=parent, + source_model=source_model, + model_id=destination_model_id, + encryption_spec=encryption_spec, + ) + else: + request = gca_model_service_compat.CopyModelRequest( + parent=parent, + source_model=source_model, + parent_model=destination_parent_model, + encryption_spec=encryption_spec, + ) + + api_client = initializer.global_config.create_client( + client_class=utils.ModelClientWithOverride, + location_override=destination_location, + credentials=initializer.global_config.credentials, + ) + + _LOGGER.log_action_start_against_resource("Copying", "", self) + + lro = api_client.copy_model( + request=request, + timeout=copy_request_timeout, + ) + + _LOGGER.log_action_started_against_resource_with_lro( + "Copy", "", self.__class__, lro + ) + + model_copy_response = lro.result(timeout=None) + + this_model = models.Model( + model_copy_response.model, + version=model_copy_response.model_version_id, + location=destination_location, + ) + + _LOGGER.log_action_completed_against_resource("", "copied", this_model) + + return this_model + def list_model_evaluations( self, ) -> List["model_evaluation.ModelEvaluation"]: diff --git a/tests/unit/aiplatform/test_models.py b/tests/unit/aiplatform/test_models.py index 7df01ecd82..a11a44917d 100644 --- a/tests/unit/aiplatform/test_models.py +++ b/tests/unit/aiplatform/test_models.py @@ -72,6 +72,7 @@ _TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}" _TEST_MODEL_NAME = "123" _TEST_MODEL_NAME_ALT = "456" +_TEST_MODEL_ID = "my-model" _TEST_MODEL_PARENT = ( f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/models/{_TEST_MODEL_NAME}" ) @@ -581,6 +582,19 @@ def delete_model_mock(): yield delete_model_mock +@pytest.fixture +def copy_model_mock(): + with mock.patch.object( + model_service_client.ModelServiceClient, "copy_model" + ) as copy_model_mock: + mock_lro = mock.Mock(ga_operation.Operation) + mock_lro.result.return_value = gca_model_service.CopyModelResponse( + model=_TEST_MODEL_RESOURCE_NAME_CUSTOM_LOCATION + ) + copy_model_mock.return_value = mock_lro + yield copy_model_mock + + @pytest.fixture def deploy_model_mock(): with mock.patch.object( @@ -2419,6 +2433,71 @@ def test_upload_tensorflow_saved_model_uploads_and_gets_model( staged_model_file_name = staged_model_file_path.split("/")[-1] assert staged_model_file_name in ["saved_model.pb", "saved_model.pbtxt"] + def test_copy_as_new_model(self, copy_model_mock, get_model_mock): + + test_model = models.Model(_TEST_ID) + test_model.copy(destination_location=_TEST_LOCATION_2) + + copy_model_mock.assert_called_once_with( + request=gca_model_service.CopyModelRequest( + parent=initializer.global_config.common_location_path( + location=_TEST_LOCATION_2 + ), + source_model=_TEST_MODEL_RESOURCE_NAME, + ), + timeout=None, + ) + + def test_copy_as_new_version(self, copy_model_mock, get_model_mock): + test_model = models.Model(_TEST_ID) + test_model.copy( + destination_location=_TEST_LOCATION_2, + destination_parent_model=_TEST_MODEL_NAME_ALT, + ) + + copy_model_mock.assert_called_once_with( + request=gca_model_service.CopyModelRequest( + parent=initializer.global_config.common_location_path( + location=_TEST_LOCATION_2 + ), + source_model=_TEST_MODEL_RESOURCE_NAME, + parent_model=model_service_client.ModelServiceClient.model_path( + _TEST_PROJECT, _TEST_LOCATION_2, _TEST_MODEL_NAME_ALT + ), + ), + timeout=None, + ) + + def test_copy_as_new_model_custom_id(self, copy_model_mock, get_model_mock): + test_model = models.Model(_TEST_ID) + test_model.copy( + destination_location=_TEST_LOCATION_2, destination_model_id=_TEST_MODEL_ID + ) + + copy_model_mock.assert_called_once_with( + request=gca_model_service.CopyModelRequest( + parent=initializer.global_config.common_location_path( + location=_TEST_LOCATION_2 + ), + source_model=_TEST_MODEL_RESOURCE_NAME, + model_id=_TEST_MODEL_ID, + ), + timeout=None, + ) + + def test_copy_with_invalid_params(self, copy_model_mock, get_model_mock): + with pytest.raises(ValueError) as e: + test_model = models.Model(_TEST_ID) + test_model.copy( + destination_location=_TEST_LOCATION, + destination_model_id=_TEST_MODEL_ID, + destination_parent_model=_TEST_MODEL_RESOURCE_NAME, + ) + + assert e.match( + regexp=r"`destination_model_id` and `destination_parent_model` can not be set together." + ) + @pytest.mark.usefixtures("get_model_mock") def test_update(self, update_model_mock, get_model_mock):