Skip to content
Permalink
Browse files
feat: add mtls support to client (#492)
* feat: add mtls feature
  • Loading branch information
arithmetic1728 committed Feb 5, 2021
1 parent 3138d41 commit 1823cadee3acf95c516d0479400e4175349ea199
Showing with 79 additions and 12 deletions.
  1. +19 −2 google/cloud/bigquery/_http.py
  2. +19 −6 google/cloud/bigquery/client.py
  3. +6 −0 tests/system/test_client.py
  4. +2 −0 tests/unit/helpers.py
  5. +14 −0 tests/unit/test__http.py
  6. +19 −4 tests/unit/test_client.py
@@ -14,25 +14,42 @@

"""Create / interact with Google BigQuery connections."""

import os
import pkg_resources

from google.cloud import _http

from google.cloud.bigquery import __version__


# TODO: Increase the minimum version of google-cloud-core to 1.6.0
# and remove this logic. See:
# https://github.com/googleapis/python-bigquery/issues/509
# Fail fast at import time: when the user opts into client certificates via
# GOOGLE_API_USE_CLIENT_CERTIFICATE, the mTLS endpoint-switching support this
# module relies on only exists in google-cloud-core >= 1.6.0.
if os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE") == "true": # pragma: NO COVER
    release = pkg_resources.get_distribution("google-cloud-core").parsed_version
    if release < pkg_resources.parse_version("1.6.0"):
        raise ImportError("google-cloud-core >= 1.6.0 is required to use mTLS feature")


class Connection(_http.JSONConnection):
    """A connection to Google BigQuery via the JSON REST API.

    Args:
        client (google.cloud.bigquery.client.Client):
            The client that owns the current connection.
        client_info (Optional[google.api_core.client_info.ClientInfo]):
            Instance used to generate user agent.
        api_endpoint (Optional[str]):
            The api_endpoint to use. If None, the library will decide what
            endpoint to use (and may auto-switch to the mTLS endpoint when
            client certificates are in effect).
    """

    DEFAULT_API_ENDPOINT = "https://bigquery.googleapis.com"
    DEFAULT_API_MTLS_ENDPOINT = "https://bigquery.mtls.googleapis.com"

    def __init__(self, client, client_info=None, api_endpoint=None):
        super(Connection, self).__init__(client, client_info)
        self.API_BASE_URL = api_endpoint or self.DEFAULT_API_ENDPOINT
        self.API_BASE_MTLS_URL = self.DEFAULT_API_MTLS_ENDPOINT
        # Endpoint auto-switching to mTLS is only permitted when the caller
        # did not pin an explicit endpoint.
        self.ALLOW_AUTO_SWITCH_TO_MTLS_URL = api_endpoint is None
        self._client_info.gapic_version = __version__
        self._client_info.client_library_version = __version__

@@ -78,10 +78,7 @@
_DEFAULT_CHUNKSIZE = 1048576 # 1024 * 1024 B = 1 MB
_MAX_MULTIPART_SIZE = 5 * 1024 * 1024
_DEFAULT_NUM_RETRIES = 6
_BASE_UPLOAD_TEMPLATE = (
"https://bigquery.googleapis.com/upload/bigquery/v2/projects/"
"{project}/jobs?uploadType="
)
_BASE_UPLOAD_TEMPLATE = "{host}/upload/bigquery/v2/projects/{project}/jobs?uploadType="
_MULTIPART_URL_TEMPLATE = _BASE_UPLOAD_TEMPLATE + "multipart"
_RESUMABLE_URL_TEMPLATE = _BASE_UPLOAD_TEMPLATE + "resumable"
_GENERIC_CONTENT_TYPE = "*/*"
@@ -2547,7 +2544,15 @@ def _initiate_resumable_upload(

if project is None:
project = self.project
upload_url = _RESUMABLE_URL_TEMPLATE.format(project=project)
# TODO: Increase the minimum version of google-cloud-core to 1.6.0
# and remove this logic. See:
# https://github.com/googleapis/python-bigquery/issues/509
hostname = (
self._connection.API_BASE_URL
if not hasattr(self._connection, "get_api_base_url_for_mtls")
else self._connection.get_api_base_url_for_mtls()
)
upload_url = _RESUMABLE_URL_TEMPLATE.format(host=hostname, project=project)

# TODO: modify ResumableUpload to take a retry.Retry object
# that it can use for the initial RPC.
@@ -2616,7 +2621,15 @@ def _do_multipart_upload(
if project is None:
project = self.project

upload_url = _MULTIPART_URL_TEMPLATE.format(project=project)
# TODO: Increase the minimum version of google-cloud-core to 1.6.0
# and remove this logic. See:
# https://github.com/googleapis/python-bigquery/issues/509
hostname = (
self._connection.API_BASE_URL
if not hasattr(self._connection, "get_api_base_url_for_mtls")
else self._connection.get_api_base_url_for_mtls()
)
upload_url = _MULTIPART_URL_TEMPLATE.format(host=hostname, project=project)
upload = MultipartUpload(upload_url, headers=headers)

if num_retries is not None:
@@ -28,6 +28,7 @@
import uuid

import psutil
import pytest
import pytz
import pkg_resources

@@ -132,6 +133,8 @@
else:
PYARROW_INSTALLED_VERSION = None

# True when this system-test run uses client certificates (mTLS); some tests
# are skipped in that environment (see the skipif on test_create_routine).
MTLS_TESTING = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE") == "true"


def _has_rows(result):
return len(result) > 0
@@ -2651,6 +2654,9 @@ def test_insert_rows_nested_nested_dictionary(self):
expected_rows = [("Some value", record)]
self.assertEqual(row_tuples, expected_rows)

@pytest.mark.skipif(
MTLS_TESTING, reason="mTLS testing has no permission to the max-value.js file"
)
def test_create_routine(self):
routine_name = "test_routine"
dataset = self.temp_dataset(_make_dataset_id("create_routine"))
@@ -21,6 +21,8 @@ def make_connection(*responses):
mock_conn = mock.create_autospec(google.cloud.bigquery._http.Connection)
mock_conn.user_agent = "testing 1.2.3"
mock_conn.api_request.side_effect = list(responses) + [NotFound("miss")]
mock_conn.API_BASE_URL = "https://bigquery.googleapis.com"
mock_conn.get_api_base_url_for_mtls = mock.Mock(return_value=mock_conn.API_BASE_URL)
return mock_conn


@@ -32,6 +32,9 @@ def _get_target_class():
return Connection

def _make_one(self, *args, **kw):
    """Instantiate the class under test, pinning the public endpoint.

    Tests that do not say otherwise get an explicit ``api_endpoint`` so
    they are unaffected by dynamic endpoint selection.
    """
    kw.setdefault("api_endpoint", "https://bigquery.googleapis.com")
    return self._get_target_class()(*args, **kw)

def test_build_api_url_no_extra_query_params(self):
@@ -138,3 +141,14 @@ def test_extra_headers_replace(self):
url=expected_uri,
timeout=self._get_default_timeout(),
)

def test_ctor_mtls(self):
    """Auto-switch to mTLS is allowed only when no endpoint is pinned."""
    mtls_url = "https://bigquery.mtls.googleapis.com"

    # No explicit endpoint: default base URL, switching allowed.
    connection = self._make_one(object(), api_endpoint=None)
    self.assertEqual(connection.ALLOW_AUTO_SWITCH_TO_MTLS_URL, True)
    self.assertEqual(connection.API_BASE_URL, "https://bigquery.googleapis.com")
    self.assertEqual(connection.API_BASE_MTLS_URL, mtls_url)

    # Pinned endpoint: switching disabled, mTLS URL unchanged.
    connection = self._make_one(object(), api_endpoint="http://foo")
    self.assertEqual(connection.ALLOW_AUTO_SWITCH_TO_MTLS_URL, False)
    self.assertEqual(connection.API_BASE_URL, "http://foo")
    self.assertEqual(connection.API_BASE_MTLS_URL, mtls_url)
@@ -2057,6 +2057,7 @@ def test_get_table_sets_user_agent(self):
url=mock.ANY, method=mock.ANY, headers=mock.ANY, data=mock.ANY
)
http.reset_mock()
http.is_mtls = False
mock_response.status_code = 200
mock_response.json.return_value = self._make_table_resource()
user_agent_override = client_info.ClientInfo(user_agent="my-application/1.2.3")
@@ -4425,7 +4426,7 @@ def _mock_transport(self, status_code, headers, content=b""):
fake_transport.request.return_value = fake_response
return fake_transport

def _initiate_resumable_upload_helper(self, num_retries=None):
def _initiate_resumable_upload_helper(self, num_retries=None, mtls=False):
from google.resumable_media.requests import ResumableUpload
from google.cloud.bigquery.client import _DEFAULT_CHUNKSIZE
from google.cloud.bigquery.client import _GENERIC_CONTENT_TYPE
@@ -4440,6 +4441,8 @@ def _initiate_resumable_upload_helper(self, num_retries=None):
fake_transport = self._mock_transport(http.client.OK, response_headers)
client = self._make_one(project=self.PROJECT, _http=fake_transport)
conn = client._connection = make_connection()
if mtls:
conn.get_api_base_url_for_mtls = mock.Mock(return_value="https://foo.mtls")

# Create some mock arguments and call the method under test.
data = b"goodbye gudbi gootbee"
@@ -4454,8 +4457,10 @@ def _initiate_resumable_upload_helper(self, num_retries=None):

# Check the returned values.
self.assertIsInstance(upload, ResumableUpload)

host_name = "https://foo.mtls" if mtls else "https://bigquery.googleapis.com"
upload_url = (
f"https://bigquery.googleapis.com/upload/bigquery/v2/projects/{self.PROJECT}"
f"{host_name}/upload/bigquery/v2/projects/{self.PROJECT}"
"/jobs?uploadType=resumable"
)
self.assertEqual(upload.upload_url, upload_url)
@@ -4494,11 +4499,14 @@ def _initiate_resumable_upload_helper(self, num_retries=None):
def test__initiate_resumable_upload(self):
    # Default (non-mTLS) path: upload URL is built from the public endpoint.
    self._initiate_resumable_upload_helper()

def test__initiate_resumable_upload_mtls(self):
    # mTLS path: helper stubs get_api_base_url_for_mtls on the connection,
    # so the upload URL should use that host instead of the default.
    self._initiate_resumable_upload_helper(mtls=True)

def test__initiate_resumable_upload_with_retry(self):
    # Retry path: num_retries is passed through to the helper.
    self._initiate_resumable_upload_helper(num_retries=11)

def _do_multipart_upload_success_helper(
self, get_boundary, num_retries=None, project=None
self, get_boundary, num_retries=None, project=None, mtls=False
):
from google.cloud.bigquery.client import _get_upload_headers
from google.cloud.bigquery.job import LoadJob
@@ -4508,6 +4516,8 @@ def _do_multipart_upload_success_helper(
fake_transport = self._mock_transport(http.client.OK, {})
client = self._make_one(project=self.PROJECT, _http=fake_transport)
conn = client._connection = make_connection()
if mtls:
conn.get_api_base_url_for_mtls = mock.Mock(return_value="https://foo.mtls")

if project is None:
project = self.PROJECT
@@ -4530,8 +4540,9 @@ def _do_multipart_upload_success_helper(
self.assertEqual(stream.tell(), size)
get_boundary.assert_called_once_with()

host_name = "https://foo.mtls" if mtls else "https://bigquery.googleapis.com"
upload_url = (
f"https://bigquery.googleapis.com/upload/bigquery/v2/projects/{project}"
f"{host_name}/upload/bigquery/v2/projects/{project}"
"/jobs?uploadType=multipart"
)
payload = (
@@ -4556,6 +4567,10 @@ def _do_multipart_upload_success_helper(
def test__do_multipart_upload(self, get_boundary):
self._do_multipart_upload_success_helper(get_boundary)

@mock.patch("google.resumable_media._upload.get_boundary", return_value=b"==0==")
def test__do_multipart_upload_mtls(self, get_boundary):
    # mTLS path: helper stubs an mTLS base URL; the multipart upload URL
    # should be built from that host.
    self._do_multipart_upload_success_helper(get_boundary, mtls=True)

@mock.patch("google.resumable_media._upload.get_boundary", return_value=b"==0==")
def test__do_multipart_upload_with_retry(self, get_boundary):
    # Retry path: num_retries is forwarded to the multipart helper.
    self._do_multipart_upload_success_helper(get_boundary, num_retries=8)

0 comments on commit 1823cad

Please sign in to comment.