Skip to content
Permalink
Browse files
feat: add support for NUMERIC type (#86)
* feat: add support for NUMERIC type

* add tests

* fix test name

* remove unused import

* add NUMERIC to param_types

* add system tests

* test: update tests to work for emulator

* style: fix lint

Co-authored-by: larkee <larkee@users.noreply.github.com>
  • Loading branch information
larkee and larkee committed Sep 3, 2020
1 parent cbfcc8b commit a79786ec3620da21aa3ce1c8bc820dab5983531d
Showing with 155 additions and 10 deletions.
  1. +5 −0 google/cloud/spanner_v1/_helpers.py
  2. +1 −0 google/cloud/spanner_v1/param_types.py
  3. +55 −0 tests/_fixtures.py
  4. +74 −10 tests/system/test_system.py
  5. +20 −0 tests/unit/test__helpers.py
@@ -15,6 +15,7 @@
"""Helper functions for Cloud Spanner."""

import datetime
import decimal
import math

import six
@@ -127,6 +128,8 @@ def _make_value_pb(value):
return Value(string_value=value)
if isinstance(value, ListValue):
return Value(list_value=value)
if isinstance(value, decimal.Decimal):
return Value(string_value=str(value))
raise ValueError("Unknown type: %s" % (value,))


@@ -201,6 +204,8 @@ def _parse_value_pb(value_pb, field_type):
_parse_value_pb(item_pb, field_type.struct_type.fields[i].type)
for (i, item_pb) in enumerate(value_pb.list_value.values)
]
elif field_type.code == type_pb2.NUMERIC:
result = decimal.Decimal(value_pb.string_value)
else:
raise ValueError("Unknown type: %s" % (field_type,))
return result
@@ -25,6 +25,7 @@
FLOAT64 = type_pb2.Type(code=type_pb2.FLOAT64)
DATE = type_pb2.Type(code=type_pb2.DATE)
TIMESTAMP = type_pb2.Type(code=type_pb2.TIMESTAMP)
NUMERIC = type_pb2.Type(code=type_pb2.NUMERIC)


def Array(element_type): # pylint: disable=invalid-name
@@ -16,6 +16,58 @@


DDL = """\
CREATE TABLE contacts (
contact_id INT64,
first_name STRING(1024),
last_name STRING(1024),
email STRING(1024) )
PRIMARY KEY (contact_id);
CREATE TABLE contact_phones (
contact_id INT64,
phone_type STRING(1024),
phone_number STRING(1024) )
PRIMARY KEY (contact_id, phone_type),
INTERLEAVE IN PARENT contacts ON DELETE CASCADE;
CREATE TABLE all_types (
pkey INT64 NOT NULL,
int_value INT64,
int_array ARRAY<INT64>,
bool_value BOOL,
bool_array ARRAY<BOOL>,
bytes_value BYTES(16),
bytes_array ARRAY<BYTES(16)>,
date_value DATE,
date_array ARRAY<DATE>,
float_value FLOAT64,
float_array ARRAY<FLOAT64>,
string_value STRING(16),
string_array ARRAY<STRING(16)>,
timestamp_value TIMESTAMP,
timestamp_array ARRAY<TIMESTAMP>,
numeric_value NUMERIC,
numeric_array ARRAY<NUMERIC>)
PRIMARY KEY (pkey);
CREATE TABLE counters (
name STRING(1024),
value INT64 )
PRIMARY KEY (name);
CREATE TABLE string_plus_array_of_string (
id INT64,
name STRING(16),
tags ARRAY<STRING(16)> )
PRIMARY KEY (id);
CREATE INDEX name ON contacts(first_name, last_name);
CREATE TABLE users_history (
id INT64 NOT NULL,
commit_ts TIMESTAMP NOT NULL OPTIONS
(allow_commit_timestamp=true),
name STRING(MAX) NOT NULL,
email STRING(MAX),
deleted BOOL NOT NULL )
PRIMARY KEY(id, commit_ts DESC);
"""

EMULATOR_DDL = """\
CREATE TABLE contacts (
contact_id INT64,
first_name STRING(1024),
@@ -66,3 +118,6 @@
"""

DDL_STATEMENTS = [stmt.strip() for stmt in DDL.split(";") if stmt.strip()]
EMULATOR_DDL_STATEMENTS = [
stmt.strip() for stmt in EMULATOR_DDL.split(";") if stmt.strip()
]
@@ -14,6 +14,7 @@

import collections
import datetime
import decimal
import math
import operator
import os
@@ -38,6 +39,7 @@
from google.cloud.spanner_v1.proto.type_pb2 import INT64
from google.cloud.spanner_v1.proto.type_pb2 import STRING
from google.cloud.spanner_v1.proto.type_pb2 import TIMESTAMP
from google.cloud.spanner_v1.proto.type_pb2 import NUMERIC
from google.cloud.spanner_v1.proto.type_pb2 import Type

from google.cloud._helpers import UTC
@@ -52,11 +54,13 @@
from test_utils.retry import RetryResult
from test_utils.system import unique_resource_id
from tests._fixtures import DDL_STATEMENTS
from tests._fixtures import EMULATOR_DDL_STATEMENTS
from tests._helpers import OpenTelemetryBase, HAS_OPENTELEMETRY_INSTALLED


CREATE_INSTANCE = os.getenv("GOOGLE_CLOUD_TESTS_CREATE_SPANNER_INSTANCE") is not None
USE_EMULATOR = os.getenv("SPANNER_EMULATOR_HOST") is not None
SKIP_BACKUP_TESTS = os.getenv("SKIP_BACKUP_TESTS") is not None

if CREATE_INSTANCE:
INSTANCE_ID = "google-cloud" + unique_resource_id("-")
@@ -92,7 +96,8 @@ class Config(object):


def _has_all_ddl(database):
return len(database.ddl_statements) == len(DDL_STATEMENTS)
ddl_statements = EMULATOR_DDL_STATEMENTS if USE_EMULATOR else DDL_STATEMENTS
return len(database.ddl_statements) == len(ddl_statements)


def _list_instances():
@@ -284,8 +289,9 @@ class TestDatabaseAPI(unittest.TestCase, _TestData):
@classmethod
def setUpClass(cls):
pool = BurstyPool(labels={"testcase": "database_api"})
ddl_statements = EMULATOR_DDL_STATEMENTS if USE_EMULATOR else DDL_STATEMENTS
cls._db = Config.INSTANCE.database(
cls.DATABASE_NAME, ddl_statements=DDL_STATEMENTS, pool=pool
cls.DATABASE_NAME, ddl_statements=ddl_statements, pool=pool
)
operation = cls._db.create()
operation.result(30) # raises on failure / timeout.
@@ -359,12 +365,13 @@ def test_update_database_ddl_with_operation_id(self):
temp_db = Config.INSTANCE.database(temp_db_id, pool=pool)
create_op = temp_db.create()
self.to_delete.append(temp_db)
ddl_statements = EMULATOR_DDL_STATEMENTS if USE_EMULATOR else DDL_STATEMENTS

# We want to make sure the operation completes.
create_op.result(240) # raises on failure / timeout.
# random but shortish always start with letter
operation_id = "a" + str(uuid.uuid4())[:8]
operation = temp_db.update_ddl(DDL_STATEMENTS, operation_id=operation_id)
operation = temp_db.update_ddl(ddl_statements, operation_id=operation_id)

self.assertEqual(operation_id, operation.operation.name.split("/")[-1])

@@ -373,7 +380,7 @@ def test_update_database_ddl_with_operation_id(self):

temp_db.reload()

self.assertEqual(len(temp_db.ddl_statements), len(DDL_STATEMENTS))
self.assertEqual(len(temp_db.ddl_statements), len(ddl_statements))

def test_db_batch_insert_then_db_snapshot_read(self):
retry = RetryInstanceState(_has_all_ddl)
@@ -447,15 +454,17 @@ def _unit_of_work(transaction, name):


@unittest.skipIf(USE_EMULATOR, "Skipping backup tests")
@unittest.skipIf(SKIP_BACKUP_TESTS, "Skipping backup tests")
class TestBackupAPI(unittest.TestCase, _TestData):
DATABASE_NAME = "test_database" + unique_resource_id("_")
DATABASE_NAME_2 = "test_database2" + unique_resource_id("_")

@classmethod
def setUpClass(cls):
pool = BurstyPool(labels={"testcase": "database_api"})
ddl_statements = EMULATOR_DDL_STATEMENTS if USE_EMULATOR else DDL_STATEMENTS
db1 = Config.INSTANCE.database(
cls.DATABASE_NAME, ddl_statements=DDL_STATEMENTS, pool=pool
cls.DATABASE_NAME, ddl_statements=ddl_statements, pool=pool
)
db2 = Config.INSTANCE.database(cls.DATABASE_NAME_2, pool=pool)
cls._db = db1
@@ -736,6 +745,8 @@ def test_list_backups(self):
(OTHER_NAN,) = struct.unpack("<d", b"\x01\x00\x01\x00\x00\x00\xf8\xff")
BYTES_1 = b"Ymlu"
BYTES_2 = b"Ym9vdHM="
NUMERIC_1 = decimal.Decimal("0.123456789")
NUMERIC_2 = decimal.Decimal("1234567890")
ALL_TYPES_TABLE = "all_types"
ALL_TYPES_COLUMNS = (
"pkey",
@@ -753,9 +764,18 @@ def test_list_backups(self):
"string_array",
"timestamp_value",
"timestamp_array",
"numeric_value",
"numeric_array",
)
EMULATOR_ALL_TYPES_COLUMNS = ALL_TYPES_COLUMNS[:-2]
AllTypesRowData = collections.namedtuple("AllTypesRowData", ALL_TYPES_COLUMNS)
AllTypesRowData.__new__.__defaults__ = tuple([None for colum in ALL_TYPES_COLUMNS])
EmulatorAllTypesRowData = collections.namedtuple(
"EmulatorAllTypesRowData", EMULATOR_ALL_TYPES_COLUMNS
)
EmulatorAllTypesRowData.__new__.__defaults__ = tuple(
[None for colum in EMULATOR_ALL_TYPES_COLUMNS]
)

ALL_TYPES_ROWDATA = (
# all nulls
@@ -769,6 +789,7 @@ def test_list_backups(self):
AllTypesRowData(pkey=106, string_value=u"VALUE"),
AllTypesRowData(pkey=107, timestamp_value=SOME_TIME),
AllTypesRowData(pkey=108, timestamp_value=NANO_TIME),
AllTypesRowData(pkey=109, numeric_value=NUMERIC_1),
# empty array values
AllTypesRowData(pkey=201, int_array=[]),
AllTypesRowData(pkey=202, bool_array=[]),
@@ -777,6 +798,7 @@ def test_list_backups(self):
AllTypesRowData(pkey=205, float_array=[]),
AllTypesRowData(pkey=206, string_array=[]),
AllTypesRowData(pkey=207, timestamp_array=[]),
AllTypesRowData(pkey=208, numeric_array=[]),
# non-empty array values, including nulls
AllTypesRowData(pkey=301, int_array=[123, 456, None]),
AllTypesRowData(pkey=302, bool_array=[True, False, None]),
@@ -785,6 +807,36 @@ def test_list_backups(self):
AllTypesRowData(pkey=305, float_array=[3.1415926, 2.71828, None]),
AllTypesRowData(pkey=306, string_array=[u"One", u"Two", None]),
AllTypesRowData(pkey=307, timestamp_array=[SOME_TIME, NANO_TIME, None]),
AllTypesRowData(pkey=308, numeric_array=[NUMERIC_1, NUMERIC_2, None]),
)
EMULATOR_ALL_TYPES_ROWDATA = (
# all nulls
EmulatorAllTypesRowData(pkey=0),
# Non-null values
EmulatorAllTypesRowData(pkey=101, int_value=123),
EmulatorAllTypesRowData(pkey=102, bool_value=False),
EmulatorAllTypesRowData(pkey=103, bytes_value=BYTES_1),
EmulatorAllTypesRowData(pkey=104, date_value=SOME_DATE),
EmulatorAllTypesRowData(pkey=105, float_value=1.4142136),
EmulatorAllTypesRowData(pkey=106, string_value=u"VALUE"),
EmulatorAllTypesRowData(pkey=107, timestamp_value=SOME_TIME),
EmulatorAllTypesRowData(pkey=108, timestamp_value=NANO_TIME),
# empty array values
EmulatorAllTypesRowData(pkey=201, int_array=[]),
EmulatorAllTypesRowData(pkey=202, bool_array=[]),
EmulatorAllTypesRowData(pkey=203, bytes_array=[]),
EmulatorAllTypesRowData(pkey=204, date_array=[]),
EmulatorAllTypesRowData(pkey=205, float_array=[]),
EmulatorAllTypesRowData(pkey=206, string_array=[]),
EmulatorAllTypesRowData(pkey=207, timestamp_array=[]),
# non-empty array values, including nulls
EmulatorAllTypesRowData(pkey=301, int_array=[123, 456, None]),
EmulatorAllTypesRowData(pkey=302, bool_array=[True, False, None]),
EmulatorAllTypesRowData(pkey=303, bytes_array=[BYTES_1, BYTES_2, None]),
EmulatorAllTypesRowData(pkey=304, date_array=[SOME_DATE, None]),
EmulatorAllTypesRowData(pkey=305, float_array=[3.1415926, 2.71828, None]),
EmulatorAllTypesRowData(pkey=306, string_array=[u"One", u"Two", None]),
EmulatorAllTypesRowData(pkey=307, timestamp_array=[SOME_TIME, NANO_TIME, None]),
)


@@ -794,8 +846,9 @@ class TestSessionAPI(OpenTelemetryBase, _TestData):
@classmethod
def setUpClass(cls):
pool = BurstyPool(labels={"testcase": "session_api"})
ddl_statements = EMULATOR_DDL_STATEMENTS if USE_EMULATOR else DDL_STATEMENTS
cls._db = Config.INSTANCE.database(
cls.DATABASE_NAME, ddl_statements=DDL_STATEMENTS, pool=pool
cls.DATABASE_NAME, ddl_statements=ddl_statements, pool=pool
)
operation = cls._db.create()
operation.result(30) # raises on failure / timeout.
@@ -899,13 +952,19 @@ def test_batch_insert_then_read_all_datatypes(self):
retry = RetryInstanceState(_has_all_ddl)
retry(self._db.reload)()

if USE_EMULATOR:
all_types_columns = EMULATOR_ALL_TYPES_COLUMNS
all_types_rowdata = EMULATOR_ALL_TYPES_ROWDATA
else:
all_types_columns = ALL_TYPES_COLUMNS
all_types_rowdata = ALL_TYPES_ROWDATA
with self._db.batch() as batch:
batch.delete(ALL_TYPES_TABLE, self.ALL)
batch.insert(ALL_TYPES_TABLE, ALL_TYPES_COLUMNS, ALL_TYPES_ROWDATA)
batch.insert(ALL_TYPES_TABLE, all_types_columns, all_types_rowdata)

with self._db.snapshot(read_timestamp=batch.committed) as snapshot:
rows = list(snapshot.read(ALL_TYPES_TABLE, ALL_TYPES_COLUMNS, self.ALL))
self._check_rows_data(rows, expected=ALL_TYPES_ROWDATA)
rows = list(snapshot.read(ALL_TYPES_TABLE, all_types_columns, self.ALL))
self._check_rows_data(rows, expected=all_types_rowdata)

def test_batch_insert_or_update_then_query(self):
retry = RetryInstanceState(_has_all_ddl)
@@ -1704,9 +1763,10 @@ def test_read_w_index(self):
MY_COLUMNS = self.COLUMNS[0], self.COLUMNS[2]
EXTRA_DDL = ["CREATE INDEX contacts_by_last_name ON contacts(last_name)"]
pool = BurstyPool(labels={"testcase": "read_w_index"})
ddl_statements = EMULATOR_DDL_STATEMENTS if USE_EMULATOR else DDL_STATEMENTS
temp_db = Config.INSTANCE.database(
"test_read" + unique_resource_id("_"),
ddl_statements=DDL_STATEMENTS + EXTRA_DDL,
ddl_statements=ddl_statements + EXTRA_DDL,
pool=pool,
)
operation = temp_db.create()
@@ -2282,6 +2342,10 @@ def test_execute_sql_w_date_bindings(self):
dates = [SOME_DATE, SOME_DATE + datetime.timedelta(days=1)]
self._bind_test_helper(DATE, SOME_DATE, dates)

@unittest.skipIf(USE_EMULATOR, "Skipping NUMERIC")
def test_execute_sql_w_numeric_bindings(self):
self._bind_test_helper(NUMERIC, NUMERIC_1, [NUMERIC_1, NUMERIC_2])

def test_execute_sql_w_query_param_struct(self):
NAME = "Phred"
COUNT = 123
@@ -208,6 +208,15 @@ def test_w_datetime(self):
self.assertIsInstance(value_pb, Value)
self.assertEqual(value_pb.string_value, datetime_helpers.to_rfc3339(now))

def test_w_numeric(self):
import decimal
from google.protobuf.struct_pb2 import Value

value = decimal.Decimal("9999999999999999999999999999.999999999")
value_pb = self._callFUT(value)
self.assertIsInstance(value_pb, Value)
self.assertEqual(value_pb.string_value, str(value))

def test_w_unknown_type(self):
with self.assertRaises(ValueError):
self._callFUT(object())
@@ -431,6 +440,17 @@ def test_w_struct(self):

self.assertEqual(self._callFUT(value_pb, field_type), VALUES)

def test_w_numeric(self):
import decimal
from google.protobuf.struct_pb2 import Value
from google.cloud.spanner_v1.proto.type_pb2 import Type, NUMERIC

VALUE = decimal.Decimal("99999999999999999999999999999.999999999")
field_type = Type(code=NUMERIC)
value_pb = Value(string_value=str(VALUE))

self.assertEqual(self._callFUT(value_pb, field_type), VALUE)

def test_w_unknown_type(self):
from google.protobuf.struct_pb2 import Value
from google.cloud.spanner_v1.proto.type_pb2 import Type

0 comments on commit a79786e

Please sign in to comment.