Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: refactor connect() function, cover it with unit tests #462

Merged
merged 9 commits into from
Aug 27, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 59 additions & 57 deletions django_spanner/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,79 +18,81 @@


class DatabaseWrapper(BaseDatabaseWrapper):
vendor = 'spanner'
display_name = 'Cloud Spanner'
vendor = "spanner"
display_name = "Cloud Spanner"

# Mapping of Field objects to their column types.
# https://cloud.google.com/spanner/docs/data-types#date-type
data_types = {
'AutoField': 'INT64',
'BigAutoField': 'INT64',
'BinaryField': 'BYTES(MAX)',
'BooleanField': 'BOOL',
'CharField': 'STRING(%(max_length)s)',
'DateField': 'DATE',
'DateTimeField': 'TIMESTAMP',
'DecimalField': 'FLOAT64',
'DurationField': 'INT64',
'EmailField': 'STRING(%(max_length)s)',
'FileField': 'STRING(%(max_length)s)',
'FilePathField': 'STRING(%(max_length)s)',
'FloatField': 'FLOAT64',
'IntegerField': 'INT64',
'BigIntegerField': 'INT64',
'IPAddressField': 'STRING(15)',
'GenericIPAddressField': 'STRING(39)',
'NullBooleanField': 'BOOL',
'OneToOneField': 'INT64',
'PositiveIntegerField': 'INT64',
'PositiveSmallIntegerField': 'INT64',
'SlugField': 'STRING(%(max_length)s)',
'SmallAutoField': 'INT64',
'SmallIntegerField': 'INT64',
'TextField': 'STRING(MAX)',
'TimeField': 'TIMESTAMP',
'UUIDField': 'STRING(32)',
"AutoField": "INT64",
"BigAutoField": "INT64",
"BinaryField": "BYTES(MAX)",
"BooleanField": "BOOL",
"CharField": "STRING(%(max_length)s)",
"DateField": "DATE",
"DateTimeField": "TIMESTAMP",
"DecimalField": "FLOAT64",
"DurationField": "INT64",
"EmailField": "STRING(%(max_length)s)",
"FileField": "STRING(%(max_length)s)",
"FilePathField": "STRING(%(max_length)s)",
"FloatField": "FLOAT64",
"IntegerField": "INT64",
"BigIntegerField": "INT64",
"IPAddressField": "STRING(15)",
"GenericIPAddressField": "STRING(39)",
"NullBooleanField": "BOOL",
"OneToOneField": "INT64",
"PositiveIntegerField": "INT64",
"PositiveSmallIntegerField": "INT64",
"SlugField": "STRING(%(max_length)s)",
"SmallAutoField": "INT64",
"SmallIntegerField": "INT64",
"TextField": "STRING(MAX)",
"TimeField": "TIMESTAMP",
"UUIDField": "STRING(32)",
}
operators = {
'exact': '= %s',
'iexact': 'REGEXP_CONTAINS(%s, %%%%s)',
"exact": "= %s",
"iexact": "REGEXP_CONTAINS(%s, %%%%s)",
# contains uses REGEXP_CONTAINS instead of LIKE to allow
# DatabaseOperations.prep_for_like_query() to do regular expression
# escaping. prep_for_like_query() is called for all the lookups that
# use REGEXP_CONTAINS except regex/iregex (see
# django.db.models.lookups.PatternLookup).
'contains': 'REGEXP_CONTAINS(%s, %%%%s)',
'icontains': 'REGEXP_CONTAINS(%s, %%%%s)',
'gt': '> %s',
'gte': '>= %s',
'lt': '< %s',
'lte': '<= %s',
"contains": "REGEXP_CONTAINS(%s, %%%%s)",
"icontains": "REGEXP_CONTAINS(%s, %%%%s)",
"gt": "> %s",
"gte": ">= %s",
"lt": "< %s",
"lte": "<= %s",
# Using REGEXP_CONTAINS instead of STARTS_WITH and ENDS_WITH for the
# same reasoning as described above for 'contains'.
'startswith': 'REGEXP_CONTAINS(%s, %%%%s)',
'endswith': 'REGEXP_CONTAINS(%s, %%%%s)',
'istartswith': 'REGEXP_CONTAINS(%s, %%%%s)',
'iendswith': 'REGEXP_CONTAINS(%s, %%%%s)',
'regex': 'REGEXP_CONTAINS(%s, %%%%s)',
'iregex': 'REGEXP_CONTAINS(%s, %%%%s)',
"startswith": "REGEXP_CONTAINS(%s, %%%%s)",
"endswith": "REGEXP_CONTAINS(%s, %%%%s)",
"istartswith": "REGEXP_CONTAINS(%s, %%%%s)",
"iendswith": "REGEXP_CONTAINS(%s, %%%%s)",
"regex": "REGEXP_CONTAINS(%s, %%%%s)",
"iregex": "REGEXP_CONTAINS(%s, %%%%s)",
}

# pattern_esc is used to generate SQL pattern lookup clauses when the
# right-hand side of the lookup isn't a raw string (it might be an
# expression or the result of a bilateral transformation). In those cases,
# special characters for REGEXP_CONTAINS operators (e.g. \, *, _) must be
# escaped on database side.
pattern_esc = r'REPLACE(REPLACE(REPLACE({}, "\\", "\\\\"), "%%", r"\%%"), "_", r"\_")'
pattern_esc = (
r'REPLACE(REPLACE(REPLACE({}, "\\", "\\\\"), "%%", r"\%%"), "_", r"\_")'
)
# These are all no-ops in favor of using REGEXP_CONTAINS in the customized
# lookups.
pattern_ops = {
'contains': '',
'icontains': '',
'startswith': '',
'istartswith': '',
'endswith': '',
'iendswith': '',
"contains": "",
"icontains": "",
"startswith": "",
"istartswith": "",
"endswith": "",
"iendswith": "",
}

Database = Database
Expand All @@ -104,19 +106,19 @@ class DatabaseWrapper(BaseDatabaseWrapper):

@property
def instance(self):
return spanner.Client().instance(self.settings_dict['INSTANCE'])
return spanner.Client().instance(self.settings_dict["INSTANCE"])

@property
def _nodb_connection(self):
raise NotImplementedError('Spanner does not have a "no db" connection.')

def get_connection_params(self):
return {
'project': self.settings_dict['PROJECT'],
'instance': self.settings_dict['INSTANCE'],
'database': self.settings_dict['NAME'],
'user_agent': 'django_spanner/0.0.1',
**self.settings_dict['OPTIONS'],
"project": self.settings_dict["PROJECT"],
"instance_id": self.settings_dict["INSTANCE"],
"database_id": self.settings_dict["NAME"],
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed arg names here, other changes in the file are just about formatting

"user_agent": "django_spanner/0.0.1",
**self.settings_dict["OPTIONS"],
}

def get_new_connection(self, conn_params):
Expand All @@ -137,7 +139,7 @@ def is_usable(self):
return False
try:
# Use a cursor directly, bypassing Django's utilities.
self.connection.cursor().execute('SELECT 1')
self.connection.cursor().execute("SELECT 1")
except Database.Error:
return False
else:
Expand Down
4 changes: 3 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
max-line-length = 119

[isort]
use_parentheses=True
combine_as_imports = true
default_section = THIRDPARTY
include_trailing_comma = true
force_grid_wrap=0
line_length = 79
multi_line_output = 5
multi_line_output = 3
149 changes: 94 additions & 55 deletions spanner_dbapi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,83 +4,122 @@
# license that can be found in the LICENSE file or at
# https://developers.google.com/open-source/licenses/bsd

from google.cloud import spanner_v1 as spanner
"""Connection-based DB API for Cloud Spanner."""

from google.cloud import spanner_v1
IlyaFaer marked this conversation as resolved.
Show resolved Hide resolved

from .connection import Connection
# These need to be included in the top-level package for PEP-0249 DB API v2.
from .exceptions import (
DatabaseError, DataError, Error, IntegrityError, InterfaceError,
InternalError, NotSupportedError, OperationalError, ProgrammingError,
DatabaseError,
DataError,
Error,
IntegrityError,
InterfaceError,
InternalError,
NotSupportedError,
OperationalError,
ProgrammingError,
Warning,
)
from .parse_utils import get_param_types
from .types import (
BINARY, DATETIME, NUMBER, ROWID, STRING, Binary, Date, DateFromTicks, Time,
TimeFromTicks, Timestamp, TimestampFromTicks,
BINARY,
DATETIME,
NUMBER,
ROWID,
STRING,
Binary,
Date,
DateFromTicks,
Time,
TimeFromTicks,
Timestamp,
TimestampFromTicks,
)
from .version import google_client_info

# Globals that MUST be defined ###
apilevel = "2.0" # Implements the Python Database API specification 2.0 version.
# We accept arguments in the format '%s' aka ANSI C print codes.
# as per https://www.python.org/dev/peps/pep-0249/#paramstyle
paramstyle = 'format'
# Threads may share the module but not connections. This is a paranoid threadsafety level,
# but it is necessary for starters to use when debugging failures. Eventually once transactions
# are working properly, we'll update the threadsafety level.
apilevel = "2.0" # supports DP-API 2.0 level.
paramstyle = "format" # ANSI C printf format codes, e.g. ...WHERE name=%s.

# Threads may share the module, but not connections. This is a paranoid threadsafety
# level, but it is necessary for starters to use when debugging failures.
# Eventually once transactions are working properly, we'll update the
# threadsafety level.
threadsafety = 1


def connect(project=None, instance=None, database=None, credentials_uri=None, user_agent=None):
def connect(instance_id, database_id, project=None, credentials=None, user_agent=None):
"""
Connect to Cloud Spanner.
Create a connection to Cloud Spanner database.

Args:
project: The id of a project that already exists.
instance: The id of an instance that already exists.
database: The name of a database that already exists.
credentials_uri: An optional string specifying where to retrieve the service
account JSON for the credentials to connect to Cloud Spanner.
:type instance_id: :class:`str`
:param instance_id: ID of the instance to connect to.
IlyaFaer marked this conversation as resolved.
Show resolved Hide resolved

Returns:
The Connection object associated to the Cloud Spanner instance.
:type database_id: :class:`str`
:param database_id: The name of the database to connect to.

Raises:
Error if it encounters any unexpected inputs.
"""
if not project:
raise Error("'project' is required.")
if not instance:
raise Error("'instance' is required.")
if not database:
raise Error("'database' is required.")
:type project: :class:`str`
:param project: (Optional) The ID of the project which owns the
instances, tables and data. If not provided, will
attempt to determine from the environment.

client_kwargs = {
'project': project,
'client_info': google_client_info(user_agent),
}
if credentials_uri:
client = spanner.Client.from_service_account_json(credentials_uri, **client_kwargs)
else:
client = spanner.Client(**client_kwargs)
:type credentials: :class:`google.auth.credentials.Credentials`
:param credentials: (Optional) The authorization credentials to attach to requests.
These credentials identify this application to the service.
If none are specified, the client will attempt to ascertain
the credentials from the environment.

:rtype: :class:`google.cloud.spanner_dbapi.connection.Connection`
:returns: Connection object associated with the given Cloud Spanner resource.

:raises: :class:`ValueError` in case of given instance/database
doesn't exist.
"""
client = spanner_v1.Client(
project=project,
credentials=credentials,
client_info=google_client_info(user_agent),
)

client_instance = client.instance(instance)
if not client_instance.exists():
raise ProgrammingError("instance '%s' does not exist." % instance)
instance = client.instance(instance_id)
if not instance.exists():
raise ValueError("instance '%s' does not exist." % instance_id)

db = client_instance.database(database, pool=spanner.pool.BurstyPool())
if not db.exists():
raise ProgrammingError("database '%s' does not exist." % database)
database = instance.database(database_id, pool=spanner_v1.pool.BurstyPool())
if not database.exists():
raise ValueError("database '%s' does not exist." % database_id)

return Connection(db)
return Connection(database)


__all__ = [
'DatabaseError', 'DataError', 'Error', 'IntegrityError', 'InterfaceError',
'InternalError', 'NotSupportedError', 'OperationalError', 'ProgrammingError',
'Warning', 'DEFAULT_USER_AGENT', 'apilevel', 'connect', 'paramstyle', 'threadsafety',
'get_param_types',
'Binary', 'Date', 'DateFromTicks', 'Time', 'TimeFromTicks', 'Timestamp',
'TimestampFromTicks',
'BINARY', 'STRING', 'NUMBER', 'DATETIME', 'ROWID', 'TimestampStr',
"DatabaseError",
"DataError",
"Error",
"IntegrityError",
"InterfaceError",
"InternalError",
"NotSupportedError",
"OperationalError",
"ProgrammingError",
"Warning",
"DEFAULT_USER_AGENT",
"apilevel",
"connect",
"paramstyle",
"threadsafety",
"get_param_types",
"Binary",
"Date",
"DateFromTicks",
"Time",
"TimeFromTicks",
"Timestamp",
"TimestampFromTicks",
"BINARY",
"STRING",
"NUMBER",
"DATETIME",
"ROWID",
"TimestampStr",
]
Loading