From 83bc768b90a852d258a4805603020a296e02d2f9 Mon Sep 17 00:00:00 2001 From: Gaurang Shah Date: Sat, 28 Oct 2023 02:29:08 -0400 Subject: [PATCH] feat: add support for dataset.default_rounding_mode (#1688) Co-authored-by: Lingqing Gan --- google/cloud/bigquery/dataset.py | 38 +++++++++++ tests/system/test_client.py | 16 ++++- tests/unit/test_create_dataset.py | 103 +++++++++++++++++++++++++++++- 3 files changed, 153 insertions(+), 4 deletions(-) diff --git a/google/cloud/bigquery/dataset.py b/google/cloud/bigquery/dataset.py index 0f1a0f3cc..af94784a4 100644 --- a/google/cloud/bigquery/dataset.py +++ b/google/cloud/bigquery/dataset.py @@ -525,6 +525,7 @@ class Dataset(object): "friendly_name": "friendlyName", "default_encryption_configuration": "defaultEncryptionConfiguration", "storage_billing_model": "storageBillingModel", + "default_rounding_mode": "defaultRoundingMode", } def __init__(self, dataset_ref) -> None: @@ -532,6 +533,43 @@ def __init__(self, dataset_ref) -> None: dataset_ref = DatasetReference.from_string(dataset_ref) self._properties = {"datasetReference": dataset_ref.to_api_repr(), "labels": {}} + @property + def default_rounding_mode(self): + """Union[str, None]: defaultRoundingMode of the dataset as set by the user + (defaults to :data:`None`). + + Set the value to one of ``'ROUND_HALF_AWAY_FROM_ZERO'``, ``'ROUND_HALF_EVEN'``, or + ``'ROUNDING_MODE_UNSPECIFIED'``. + + See `default rounding mode + `_ + in REST API docs and `updating the default rounding model + `_ + guide. + + Raises: + ValueError: for invalid value types. + """ + return self._properties.get("defaultRoundingMode") + + @default_rounding_mode.setter + def default_rounding_mode(self, value): + possible_values = [ + "ROUNDING_MODE_UNSPECIFIED", + "ROUND_HALF_AWAY_FROM_ZERO", + "ROUND_HALF_EVEN", + ] + if not isinstance(value, str) and value is not None: + raise ValueError("Pass a string, or None") + if value is None: + self._properties["defaultRoundingMode"] = "ROUNDING_MODE_UNSPECIFIED" + if value not in possible_values and value is not None: + raise ValueError( + f'rounding mode needs to be one of {",".join(possible_values)}' + ) + if value: + self._properties["defaultRoundingMode"] = value + @property def project(self): """str: Project ID of the project bound to the dataset.""" diff --git a/tests/system/test_client.py b/tests/system/test_client.py index d3b95ec49..09606590e 100644 --- a/tests/system/test_client.py +++ b/tests/system/test_client.py @@ -265,6 +265,13 @@ def test_get_dataset(self): self.assertEqual(got.friendly_name, "Friendly") self.assertEqual(got.description, "Description") + def test_create_dataset_with_default_rounding_mode(self): + DATASET_ID = _make_dataset_id("create_dataset_rounding_mode") + dataset = self.temp_dataset(DATASET_ID, default_rounding_mode="ROUND_HALF_EVEN") + + self.assertTrue(_dataset_exists(dataset)) + self.assertEqual(dataset.default_rounding_mode, "ROUND_HALF_EVEN") + def test_update_dataset(self): dataset = self.temp_dataset(_make_dataset_id("update_dataset")) self.assertTrue(_dataset_exists(dataset)) @@ -2286,12 +2293,15 @@ def test_nested_table_to_arrow(self): self.assertTrue(pyarrow.types.is_list(record_col[1].type)) self.assertTrue(pyarrow.types.is_int64(record_col[1].type.value_type)) - def temp_dataset(self, dataset_id, location=None): + def temp_dataset(self, dataset_id, *args, **kwargs): project = Config.CLIENT.project dataset_ref = bigquery.DatasetReference(project, dataset_id) dataset = Dataset(dataset_ref) - if location: - dataset.location = location + if kwargs.get("location"): + dataset.location = kwargs.get("location") + if kwargs.get("default_rounding_mode"): + dataset.default_rounding_mode = kwargs.get("default_rounding_mode") + dataset = helpers.retry_403(Config.CLIENT.create_dataset)(dataset) self.to_delete.append(dataset) return dataset diff --git a/tests/unit/test_create_dataset.py b/tests/unit/test_create_dataset.py index 81af52261..3b2e644d9 100644 --- a/tests/unit/test_create_dataset.py +++ b/tests/unit/test_create_dataset.py @@ -63,6 +63,7 @@ def test_create_dataset_w_attrs(client, PROJECT, DS_ID): "datasetId": "starry-skies", "tableId": "northern-hemisphere", } + DEFAULT_ROUNDING_MODE = "ROUND_HALF_EVEN" RESOURCE = { "datasetReference": {"projectId": PROJECT, "datasetId": DS_ID}, "etag": "etag", @@ -73,6 +74,7 @@ def test_create_dataset_w_attrs(client, PROJECT, DS_ID): "defaultTableExpirationMs": "3600", "labels": LABELS, "access": [{"role": "OWNER", "userByEmail": USER_EMAIL}, {"view": VIEW}], + "defaultRoundingMode": DEFAULT_ROUNDING_MODE, } conn = client._connection = make_connection(RESOURCE) entries = [ @@ -88,8 +90,8 @@ def test_create_dataset_w_attrs(client, PROJECT, DS_ID): before.default_table_expiration_ms = 3600 before.location = LOCATION before.labels = LABELS + before.default_rounding_mode = DEFAULT_ROUNDING_MODE after = client.create_dataset(before) - assert after.dataset_id == DS_ID assert after.project == PROJECT assert after.etag == RESOURCE["etag"] @@ -99,6 +101,7 @@ def test_create_dataset_w_attrs(client, PROJECT, DS_ID): assert after.location == LOCATION assert after.default_table_expiration_ms == 3600 assert after.labels == LABELS + assert after.default_rounding_mode == DEFAULT_ROUNDING_MODE conn.api_request.assert_called_once_with( method="POST", @@ -109,6 +112,7 @@ def test_create_dataset_w_attrs(client, PROJECT, DS_ID): "friendlyName": FRIENDLY_NAME, "location": LOCATION, "defaultTableExpirationMs": "3600", + "defaultRoundingMode": DEFAULT_ROUNDING_MODE, "access": [ {"role": "OWNER", "userByEmail": USER_EMAIL}, {"view": VIEW, "role": None}, @@ -365,3 +369,100 @@ def test_create_dataset_alreadyexists_w_exists_ok_true(PROJECT, DS_ID, LOCATION) mock.call(method="GET", path=get_path, timeout=DEFAULT_TIMEOUT), ] ) + + +def test_create_dataset_with_default_rounding_mode_if_value_is_none( + PROJECT, DS_ID, LOCATION +): + default_rounding_mode = None + path = "/projects/%s/datasets" % PROJECT + resource = { + "datasetReference": {"projectId": PROJECT, "datasetId": DS_ID}, + "etag": "etag", + "id": "{}:{}".format(PROJECT, DS_ID), + "location": LOCATION, + } + client = make_client(location=LOCATION) + conn = client._connection = make_connection(resource) + + ds_ref = DatasetReference(PROJECT, DS_ID) + before = Dataset(ds_ref) + before.default_rounding_mode = default_rounding_mode + after = client.create_dataset(before) + + assert after.dataset_id == DS_ID + assert after.project == PROJECT + assert after.default_rounding_mode is None + + conn.api_request.assert_called_once_with( + method="POST", + path=path, + data={ + "datasetReference": {"projectId": PROJECT, "datasetId": DS_ID}, + "labels": {}, + "location": LOCATION, + "defaultRoundingMode": "ROUNDING_MODE_UNSPECIFIED", + }, + timeout=DEFAULT_TIMEOUT, + ) + + +def test_create_dataset_with_default_rounding_mode_if_value_is_not_string( + PROJECT, DS_ID, LOCATION +): + default_rounding_mode = 10 + ds_ref = DatasetReference(PROJECT, DS_ID) + dataset = Dataset(ds_ref) + with pytest.raises(ValueError) as e: + dataset.default_rounding_mode = default_rounding_mode + assert str(e.value) == "Pass a string, or None" + + +def test_create_dataset_with_default_rounding_mode_if_value_is_not_in_possible_values( + PROJECT, DS_ID +): + default_rounding_mode = "ROUND_HALF_AWAY_FROM_ZEROS" + ds_ref = DatasetReference(PROJECT, DS_ID) + dataset = Dataset(ds_ref) + with pytest.raises(ValueError) as e: + dataset.default_rounding_mode = default_rounding_mode + assert ( + str(e.value) + == "rounding mode needs to be one of ROUNDING_MODE_UNSPECIFIED,ROUND_HALF_AWAY_FROM_ZERO,ROUND_HALF_EVEN" + ) + + +def test_create_dataset_with_default_rounding_mode_if_value_is_in_possible_values( + PROJECT, DS_ID, LOCATION +): + default_rounding_mode = "ROUND_HALF_AWAY_FROM_ZERO" + path = "/projects/%s/datasets" % PROJECT + resource = { + "datasetReference": {"projectId": PROJECT, "datasetId": DS_ID}, + "etag": "etag", + "id": "{}:{}".format(PROJECT, DS_ID), + "location": LOCATION, + } + client = make_client(location=LOCATION) + conn = client._connection = make_connection(resource) + + ds_ref = DatasetReference(PROJECT, DS_ID) + before = Dataset(ds_ref) + before.default_rounding_mode = default_rounding_mode + after = client.create_dataset(before) + + assert after.dataset_id == DS_ID + assert after.project == PROJECT + assert after.default_rounding_mode is None + + conn.api_request.assert_called_once_with( + method="POST", + path=path, + data={ + "datasetReference": {"projectId": PROJECT, "datasetId": DS_ID}, + "labels": {}, + "location": LOCATION, + "defaultRoundingMode": default_rounding_mode, + }, + timeout=DEFAULT_TIMEOUT, + )