Implement CloudDataTransferServiceRunJobOperator (apache#39154)
* Implement CloudDataTransferServiceRunJobOperator

* Add references from storage-transfer docs

* Add unit test for `run_transfer_job`

* Add docs and example dag usage for `CloudDataTransferServiceRunJobOperator`

* Fix doctest errors

* Validate inputs inside execute function

* Remove validation check in the constructor

* Fix failing test

* Fix ruff linter issues

* Ensure consistent `project_id` usage

Co-authored-by: Josh Fell <48934154+josh-fell@users.noreply.github.com>

---------

Co-authored-by: Josh Fell <48934154+josh-fell@users.noreply.github.com>
2 people authored and fdemiane committed Jun 6, 2024
1 parent 7250b5c commit 4951385
Showing 6 changed files with 222 additions and 0 deletions.
@@ -344,6 +344,32 @@ def delete_transfer_job(self, job_name: str, project_id: str) -> None:
            .execute(num_retries=self.num_retries)
        )

    @GoogleBaseHook.fallback_to_default_project_id
    def run_transfer_job(self, job_name: str, project_id: str) -> dict:
        """Run Google Storage Transfer Service job.

        :param job_name: (Required) Name of the job to be run
        :param project_id: (Optional) the ID of the project that owns the Transfer
            Job. If set to None or missing, the default project_id from the Google Cloud
            connection is used.
        :return: If successful, Operation. See:
            https://cloud.google.com/storage-transfer/docs/reference/rest/v1/Operation

        .. seealso:: https://cloud.google.com/storage-transfer/docs/reference/rest/v1/transferJobs/run
        """
        return (
            self.get_conn()
            .transferJobs()
            .run(
                jobName=job_name,
                body={
                    PROJECT_ID: project_id,
                },
            )
            .execute(num_retries=self.num_retries)
        )

    def cancel_transfer_operation(self, operation_name: str) -> None:
        """Cancel a transfer operation in Google Storage Transfer Service.
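As a quick illustration of the new hook method, here is a minimal sketch (not part of this commit) of calling run_transfer_job directly; it assumes a configured google_cloud_default connection, and the job name and project ID are hypothetical placeholders.

from airflow.providers.google.cloud.hooks.cloud_storage_transfer_service import (
    CloudDataTransferServiceHook,
)

# Assumes a configured "google_cloud_default" Airflow connection.
hook = CloudDataTransferServiceHook(gcp_conn_id="google_cloud_default")

# "transferJobs/123456789" and "my-project" are hypothetical placeholders.
operation = hook.run_transfer_job(
    job_name="transferJobs/123456789",
    project_id="my-project",
)

# The returned dict is the started Operation resource; its "name" can be used to poll status.
print(operation["name"])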
@@ -443,6 +443,82 @@ def execute(self, context: Context) -> None:
        hook.delete_transfer_job(job_name=self.job_name, project_id=self.project_id)


class CloudDataTransferServiceRunJobOperator(GoogleCloudBaseOperator):
    """
    Runs a transfer job.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:CloudDataTransferServiceRunJobOperator`

    :param job_name: (Required) Name of the job to be run
    :param project_id: (Optional) the ID of the project that owns the Transfer
        Job. If set to None or missing, the default project_id from the Google Cloud
        connection is used.
    :param gcp_conn_id: The connection ID used to connect to Google Cloud.
    :param api_version: API version used (e.g. v1).
    :param google_impersonation_chain: Optional Google service account to impersonate using
        short-term credentials, or chained list of accounts required to get the access_token
        of the last account in the list, which will be impersonated in the request.
        If set as a string, the account must grant the originating account
        the Service Account Token Creator IAM role.
        If set as a sequence, the identities from the list must grant
        Service Account Token Creator IAM role to the directly preceding identity, with first
        account from the list granting this role to the originating account (templated).
    """

    # [START gcp_transfer_job_run_template_fields]
    template_fields: Sequence[str] = (
        "job_name",
        "project_id",
        "gcp_conn_id",
        "api_version",
        "google_impersonation_chain",
    )
    # [END gcp_transfer_job_run_template_fields]
    operator_extra_links = (CloudStorageTransferJobLink(),)

    def __init__(
        self,
        *,
        job_name: str,
        gcp_conn_id: str = "google_cloud_default",
        api_version: str = "v1",
        project_id: str = PROVIDE_PROJECT_ID,
        google_impersonation_chain: str | Sequence[str] | None = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.job_name = job_name
        self.project_id = project_id
        self.gcp_conn_id = gcp_conn_id
        self.api_version = api_version
        self.google_impersonation_chain = google_impersonation_chain

    def _validate_inputs(self) -> None:
        if not self.job_name:
            raise AirflowException("The required parameter 'job_name' is empty or None")

    def execute(self, context: Context) -> dict:
        self._validate_inputs()
        hook = CloudDataTransferServiceHook(
            api_version=self.api_version,
            gcp_conn_id=self.gcp_conn_id,
            impersonation_chain=self.google_impersonation_chain,
        )

        project_id = self.project_id or hook.project_id
        if project_id:
            CloudStorageTransferJobLink.persist(
                context=context,
                task_instance=self,
                project_id=project_id,
                job_name=self.job_name,
            )

        return hook.run_transfer_job(job_name=self.job_name, project_id=project_id)


class CloudDataTransferServiceGetOperationOperator(GoogleCloudBaseOperator):
    """
    Gets the latest state of a long-running operation in Google Storage Transfer Service.
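For orientation, here is a minimal sketch (not part of this commit) of the new operator in a standalone DAG. The dag_id, job name, and project ID are hypothetical, and the downstream task only illustrates that execute() returns the Operation resource, which Airflow pushes to XCom.

from __future__ import annotations

import datetime

from airflow import DAG
from airflow.operators.python import PythonOperator
from airflow.providers.google.cloud.operators.cloud_storage_transfer_service import (
    CloudDataTransferServiceRunJobOperator,
)

with DAG(
    dag_id="example_run_transfer_job_sketch",  # hypothetical DAG id
    start_date=datetime.datetime(2024, 1, 1),
    schedule=None,
    catchup=False,
):
    run_job = CloudDataTransferServiceRunJobOperator(
        task_id="run_transfer_job",
        job_name="transferJobs/123456789",  # hypothetical, server-assigned job name
        project_id="my-project",  # hypothetical project ID
    )

    def _log_operation(ti):
        # The operator's return value (the started Operation resource) is available via XCom.
        operation = ti.xcom_pull(task_ids="run_transfer_job")
        print(operation["name"])

    log_operation = PythonOperator(task_id="log_operation", python_callable=_log_operation)

    run_job >> log_operation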
@@ -130,6 +130,41 @@ See `Google Cloud Transfer Service - REST Resource: transferJobs - Status

.. _howto/operator:CloudDataTransferServiceRunJobOperator:

CloudDataTransferServiceRunJobOperator
-----------------------------------------

Runs a transfer job.

For parameter definition, take a look at
:class:`~airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceRunJobOperator`.

Using the operator
""""""""""""""""""

.. exampleinclude:: /../../tests/system/providers/google/cloud/storage_transfer/example_cloud_storage_transfer_service_gcp.py
    :language: python
    :dedent: 4
    :start-after: [START howto_operator_gcp_transfer_run_job]
    :end-before: [END howto_operator_gcp_transfer_run_job]

Templating
""""""""""

.. literalinclude:: /../../airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py
    :language: python
    :dedent: 4
    :start-after: [START gcp_transfer_job_run_template_fields]
    :end-before: [END gcp_transfer_job_run_template_fields]

More information
""""""""""""""""

See `Google Cloud Transfer Service - REST Resource: transferJobs - Run
<https://cloud.google.com/storage-transfer/docs/reference/rest/v1/transferJobs/run>`_

.. _howto/operator:CloudDataTransferServiceUpdateJobOperator:

CloudDataTransferServiceUpdateJobOperator
-----------------------------------------

@@ -287,6 +287,25 @@ def test_delete_transfer_job(self, get_conn):
        )
        execute_method.assert_called_once_with(num_retries=5)

    @mock.patch(
        "airflow.providers.google.cloud.hooks.cloud_storage_transfer_service."
        "CloudDataTransferServiceHook.get_conn"
    )
    def test_run_transfer_job(self, get_conn):
        run_method = get_conn.return_value.transferJobs.return_value.run
        execute_method = run_method.return_value.execute
        execute_method.return_value = TEST_TRANSFER_OPERATION

        res = self.gct_hook.run_transfer_job(job_name=TEST_TRANSFER_JOB_NAME, project_id=TEST_PROJECT_ID)
        assert res == TEST_TRANSFER_OPERATION
        run_method.assert_called_once_with(
            jobName=TEST_TRANSFER_JOB_NAME,
            body={
                PROJECT_ID: TEST_PROJECT_ID,
            },
        )
        execute_method.assert_called_once_with(num_retries=5)

    @mock.patch(
        "airflow.providers.google.cloud.hooks.cloud_storage_transfer_service"
        ".CloudDataTransferServiceHook.get_conn"
@@ -56,6 +56,7 @@
    CloudDataTransferServiceListOperationsOperator,
    CloudDataTransferServicePauseOperationOperator,
    CloudDataTransferServiceResumeOperationOperator,
    CloudDataTransferServiceRunJobOperator,
    CloudDataTransferServiceS3ToGCSOperator,
    CloudDataTransferServiceUpdateJobOperator,
    TransferJobPreprocessor,
@@ -493,6 +494,61 @@ def test_job_delete_should_throw_ex_when_name_none(self):
            CloudDataTransferServiceDeleteJobOperator(job_name="", task_id="task-id")


class TestGcpStorageTransferJobRunOperator:
    @mock.patch(
        "airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook"
    )
    def test_job_run(self, mock_hook):
        mock_hook.return_value.run_transfer_job.return_value = VALID_OPERATION
        op = CloudDataTransferServiceRunJobOperator(
            job_name=JOB_NAME,
            project_id=GCP_PROJECT_ID,
            task_id="task-id",
            google_impersonation_chain=IMPERSONATION_CHAIN,
        )
        result = op.execute(context=mock.MagicMock())
        mock_hook.assert_called_once_with(
            api_version="v1",
            gcp_conn_id="google_cloud_default",
            impersonation_chain=IMPERSONATION_CHAIN,
        )
        mock_hook.return_value.run_transfer_job.assert_called_once_with(
            job_name=JOB_NAME, project_id=GCP_PROJECT_ID
        )
        assert result == VALID_OPERATION

    # Setting all the operator's input parameters as templated dag_ids
    # (could be anything else) just to test if the templating works for all
    # fields
    @pytest.mark.db_test
    @mock.patch(
        "airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook"
    )
    def test_job_run_with_templates(self, _, create_task_instance_of_operator):
        dag_id = "test_job_run_with_templates"
        ti = create_task_instance_of_operator(
            CloudDataTransferServiceRunJobOperator,
            dag_id=dag_id,
            job_name="{{ dag.dag_id }}",
            project_id="{{ dag.dag_id }}",
            gcp_conn_id="{{ dag.dag_id }}",
            api_version="{{ dag.dag_id }}",
            google_impersonation_chain="{{ dag.dag_id }}",
            task_id=TASK_ID,
        )
        ti.render_templates()
        assert dag_id == ti.task.job_name
        assert dag_id == ti.task.project_id
        assert dag_id == ti.task.gcp_conn_id
        assert dag_id == ti.task.api_version
        assert dag_id == ti.task.google_impersonation_chain

    def test_job_run_should_throw_ex_when_name_none(self):
        op = CloudDataTransferServiceRunJobOperator(job_name="", task_id="task-id")
        with pytest.raises(AirflowException, match="The required parameter 'job_name' is empty or None"):
            op.execute(context=mock.MagicMock())

class TestGpcStorageTransferOperationsGetOperator:
    @mock.patch(
        "airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook"
@@ -53,6 +53,7 @@
    CloudDataTransferServiceDeleteJobOperator,
    CloudDataTransferServiceGetOperationOperator,
    CloudDataTransferServiceListOperationsOperator,
    CloudDataTransferServiceRunJobOperator,
    CloudDataTransferServiceUpdateJobOperator,
)
from airflow.providers.google.cloud.operators.gcs import GCSCreateBucketOperator, GCSDeleteBucketOperator
@@ -147,6 +148,14 @@
        expected_statuses={GcpTransferOperationStatus.SUCCESS},
    )

    # [START howto_operator_gcp_transfer_run_job]
    run_transfer = CloudDataTransferServiceRunJobOperator(
        task_id="run_transfer",
        job_name="{{task_instance.xcom_pull('create_transfer')['name']}}",
        project_id=PROJECT_ID_TRANSFER,
    )
    # [END howto_operator_gcp_transfer_run_job]

    list_operations = CloudDataTransferServiceListOperationsOperator(
        task_id="list_operations",
        request_filter={
@@ -180,6 +189,7 @@
        >> create_transfer
        >> wait_for_transfer
        >> update_transfer
        >> run_transfer
        >> list_operations
        >> get_operation
        >> [delete_transfer, delete_bucket_src, delete_bucket_dst]
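As a side note on the templated job_name above: a hedged alternative (not part of this commit) is to resolve the name created by create_transfer through a TaskFlow task and an XComArg instead of a Jinja expression. The sketch below assumes the same example-DAG context (create_transfer and PROJECT_ID_TRANSFER as defined in this file) and a hypothetical task id.

from airflow.decorators import task

@task
def extract_job_name(transfer_job: dict) -> str:
    # The TransferJob resource returned by create_transfer carries the server-assigned name.
    return transfer_job["name"]

# Hypothetical alternative task; equivalent to the templated job_name used above.
run_transfer_explicit = CloudDataTransferServiceRunJobOperator(
    task_id="run_transfer_explicit",
    job_name=extract_job_name(create_transfer.output),
    project_id=PROJECT_ID_TRANSFER,
)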