Skip to content

Commit

Permalink
fix: PipelineJob should only pass bearer tokens for AR URIs (#1717)
Browse files Browse the repository at this point in the history
When downloading compiled KFP pipelines over HTTPS, we only need to pass a bearer token when we need to authenticate for services like Artifact Registry. We may get unexpected behavior passing this token in all HTTPS requests, which is the current behavior.

Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly:
- [x] Make sure to open an issue as a [bug/issue](https://togithub.com/googleapis/python-aiplatform/issues/new/choose) before writing your code!  That way we can discuss the change, evaluate designs, and agree on the general idea
- [x] Ensure the tests and linter pass
- [x] Code coverage does not decrease (if any source code was changed)
- [x] Appropriate docs were updated (if necessary)

Fixes b/251143831 🦕
  • Loading branch information
TheMichaelHu committed Oct 7, 2022
1 parent dde9ba1 commit b43851c
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 8 deletions.
4 changes: 3 additions & 1 deletion google/cloud/aiplatform/utils/yaml_utils.py
Expand Up @@ -52,8 +52,10 @@ def load_yaml(
if path.startswith("gs://"):
return _load_yaml_from_gs_uri(path, project, credentials)
elif path.startswith("http://") or path.startswith("https://"):
if _VALID_AR_URL.match(path) or _VALID_HTTPS_URL.match(path):
if _VALID_AR_URL.match(path):
return _load_yaml_from_https_uri(path, credentials)
elif _VALID_HTTPS_URL.match(path):
return _load_yaml_from_https_uri(path)
else:
raise ValueError(
"Invalid HTTPS URI. If not using Artifact Registry, please "
Expand Down
27 changes: 20 additions & 7 deletions tests/unit/aiplatform/test_utils.py
Expand Up @@ -21,14 +21,15 @@
import json
import os
import textwrap
from typing import Callable, Dict, Optional
from typing import Callable, Dict, Optional, Tuple
from unittest import mock
from unittest.mock import patch
from urllib import request as urllib_request

import pytest
import yaml
from google.api_core import client_options, gapic_v1
from google.auth import credentials
from google.cloud import aiplatform
from google.cloud import storage
from google.cloud.aiplatform import compat, utils
Expand Down Expand Up @@ -775,15 +776,15 @@ def json_file(tmp_path):


@pytest.fixture(scope="function")
def mock_request_urlopen(request: str) -> str:
def mock_request_urlopen(request: str) -> Tuple[str, mock.MagicMock]:
data = {"key": "val", "list": ["1", 2, 3.0]}
with mock.patch.object(urllib_request, "urlopen") as mock_urlopen:
mock_read_response = mock.MagicMock()
mock_decode_response = mock.MagicMock()
mock_decode_response.return_value = json.dumps(data)
mock_read_response.return_value.decode = mock_decode_response
mock_urlopen.return_value.read = mock_read_response
yield request.param
yield request.param, mock_urlopen


class TestYamlUtils:
Expand All @@ -802,10 +803,17 @@ def test_load_yaml_from_local_file__with_json(self, json_file):
["https://us-central1-kfp.pkg.dev/proj/repo/pack/latest"],
indirect=True,
)
def test_load_yaml_from_ar_uri(self, mock_request_urlopen):
actual = yaml_utils.load_yaml(mock_request_urlopen)
def test_load_yaml_from_ar_uri_passes_creds(self, mock_request_urlopen):
url, mock_urlopen = mock_request_urlopen
mock_credentials = mock.create_autospec(credentials.Credentials, instance=True)
mock_credentials.valid = True
mock_credentials.token = "some_token"
actual = yaml_utils.load_yaml(url, credentials=mock_credentials)
expected = {"key": "val", "list": ["1", 2, 3.0]}
assert actual == expected
assert mock_urlopen.call_args[0][0].headers == {
"Authorization": "Bearer some_token"
}

@pytest.mark.parametrize(
"mock_request_urlopen",
Expand All @@ -816,10 +824,15 @@ def test_load_yaml_from_ar_uri(self, mock_request_urlopen):
],
indirect=True,
)
def test_load_yaml_from_https_uri(self, mock_request_urlopen):
actual = yaml_utils.load_yaml(mock_request_urlopen)
def test_load_yaml_from_https_uri_ignores_creds(self, mock_request_urlopen):
url, mock_urlopen = mock_request_urlopen
mock_credentials = mock.create_autospec(credentials.Credentials, instance=True)
mock_credentials.valid = True
mock_credentials.token = "some_token"
actual = yaml_utils.load_yaml(url, credentials=mock_credentials)
expected = {"key": "val", "list": ["1", 2, 3.0]}
assert actual == expected
assert mock_urlopen.call_args[0][0].headers == {}

@pytest.mark.parametrize(
"uri",
Expand Down

0 comments on commit b43851c

Please sign in to comment.