Skip to content

Commit

Permalink
mysql base backend
Browse files Browse the repository at this point in the history
  • Loading branch information
ojengwa committed Feb 16, 2016
1 parent 0188bfd commit b7be57e
Show file tree
Hide file tree
Showing 2 changed files with 170 additions and 46 deletions.
100 changes: 54 additions & 46 deletions ibu/backends/mysql/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,35 +2,34 @@
MySQL database backend.
Requires mysqlclient: https://pypi.python.org/pypi/mysqlclient/
MySQLdb is supported for Python 2 only: http://sourceforge.net/projects/mysql-python
MySQLdb is supported for Python 2 only:
http://sourceforge.net/projects/mysql-python
"""
from __future__ import unicode_literals

import datetime
import re
import sys
import warnings
from time import timezone

from django.conf import settings
from django.db import utils
from django.db.backends import utils as backend_utils
from django.db.backends.base.base import BaseDatabaseWrapper
from django.utils import six, timezone
from django.utils.deprecation import RemovedInDjango20Warning
from django.utils.encoding import force_str
from django.utils.functional import cached_property
from django.utils.safestring import SafeBytes, SafeText
import six
from ibu import connection as utils
from ibu.backends import utils as backend_utils
from ibu.backends.base.base import BaseDatabaseWrapper
from ibu.backends.utils import cached_property

try:
import MySQLdb as Database
except ImportError as e:
from django.core.exceptions import ImproperlyConfigured
from ibu.backends.base.base import ImproperlyConfigured
raise ImproperlyConfigured("Error loading MySQLdb module: %s" % e)

from MySQLdb.constants import CLIENT, FIELD_TYPE # isort:skip
from MySQLdb.converters import Thing2Literal, conversions # isort:skip

# Some of these import MySQLdb, so import them after checking if it's installed.
# Some of these import MySQLdb, so import them after checking if it's
# installed.
from .client import DatabaseClient # isort:skip
from .creation import DatabaseCreation # isort:skip
from .features import DatabaseFeatures # isort:skip
Expand All @@ -44,26 +43,22 @@
# inadvertently passes the version test.
version = Database.version_info
if (version < (1, 2, 1) or (version[:3] == (1, 2, 1) and
(len(version) < 5 or version[3] != 'final' or version[4] < 2))):
(len(version) < 5 or version[3] != 'final' or
version[4] < 2))):
from django.core.exceptions import ImproperlyConfigured
raise ImproperlyConfigured("MySQLdb-1.2.1p2 or newer is required; you have %s" % Database.__version__)
raise ImproperlyConfigured(
"MySQLdb-1.2.1p2 or newer is required; you have %s" %
Database.__version__)


DatabaseError = Database.DatabaseError
IntegrityError = Database.IntegrityError


def adapt_datetime_warn_on_aware_datetime(value, conv):
# Remove this function and rely on the default adapter in Django 2.0.
if settings.USE_TZ and timezone.is_aware(value):
warnings.warn(
"The MySQL database adapter received an aware datetime (%s), "
"probably from cursor.execute(). Update your code to pass a "
"naive datetime in the database connection's time zone (UTC by "
"default).", RemovedInDjango20Warning)
# This doesn't account for the database connection's timezone,
# which isn't known. (That's why this adapter is deprecated.)
value = value.astimezone(timezone.utc).replace(tzinfo=None)
# This doesn't account for the database connection's timezone,
# which isn't known. (That's why this adapter is deprecated.)
value = value.astimezone(timezone.utc).replace(tzinfo=None)
return Thing2Literal(value.strftime("%Y-%m-%d %H:%M:%S.%f"), conv)

# MySQLdb-1.2.1 returns TIME columns as timedelta -- they are more like
Expand Down Expand Up @@ -99,7 +94,7 @@ class CursorWrapper(object):
particular exception instances and reraise them with the right types.
Implemented as a wrapper, rather than a subclass, so that we aren't stuck
to the particular underlying representation returned by Connection.cursor().
to the particular underlying representation returned by Connection.cursor()
"""
codes_for_integrityerror = (1048,)

Expand All @@ -114,7 +109,8 @@ def execute(self, query, args=None):
# Map some error codes to IntegrityError, since they seem to be
# misclassified and Django would prefer the more logical place.
if e.args[0] in self.codes_for_integrityerror:
six.reraise(utils.IntegrityError, utils.IntegrityError(*tuple(e.args)), sys.exc_info()[2])
six.reraise(utils.IntegrityError, utils.IntegrityError(
*tuple(e.args)), sys.exc_info()[2])
raise

def executemany(self, query, args):
Expand All @@ -124,7 +120,8 @@ def executemany(self, query, args):
# Map some error codes to IntegrityError, since they seem to be
# misclassified and Django would prefer the more logical place.
if e.args[0] in self.codes_for_integrityerror:
six.reraise(utils.IntegrityError, utils.IntegrityError(*tuple(e.args)), sys.exc_info()[2])
six.reraise(utils.IntegrityError, utils.IntegrityError(
*tuple(e.args)), sys.exc_info()[2])
raise

def __getattr__(self, attr):
Expand All @@ -148,7 +145,7 @@ def __exit__(self, type, value, traceback):
class DatabaseWrapper(BaseDatabaseWrapper):
vendor = 'mysql'
# This dictionary maps Field objects to their associated MySQL column
# types, as strings. Column-type strings can contain format strings; they'll
# types, as strings. Column-type strings can contain format strings; they'l
# be interpolated against the values of Field.__dict__ before being output.
# If a column type is set to None, it won't be included in the output.
_data_types = {
Expand Down Expand Up @@ -183,7 +180,8 @@ class DatabaseWrapper(BaseDatabaseWrapper):
@cached_property
def data_types(self):
if self.features.supports_microsecond_precision:
return dict(self._data_types, DateTimeField='datetime(6)', TimeField='time(6)')
return dict(self._data_types, DateTimeField='datetime(6)',
TimeField='time(6)')
else:
return self._data_types

Expand All @@ -205,14 +203,18 @@ def data_types(self):
}

# The patterns below are used to generate SQL pattern lookup clauses when
# the right-hand side of the lookup isn't a raw string (it might be an expression
# the right-hand side of the lookup isn't a raw string
# (it might be an expression
# or the result of a bilateral transformation).
# In those cases, special characters for LIKE operators (e.g. \, *, _) should be
# In those cases, special characters for LIKE operators (e.g. \, *, _)
# should be
# escaped on database side.
#
# Note: we use str.format() here for readability as '%' is used as a wildcard for
# Note: we use str.format() here for readability as '%' is used as a
# wildcard for
# the LIKE operator.
pattern_esc = r"REPLACE(REPLACE(REPLACE({}, '\\', '\\\\'), '%%', '\%%'), '_', '\_')"
pattern_esc = r"REPLACE(REPLACE(REPLACE({}, '\\', '\\\\'),\
'%%', '\%%'), '_', '\_')"
pattern_ops = {
'contains': "LIKE BINARY CONCAT('%%', {}, '%%')",
'icontains': "LIKE CONCAT('%%', {}, '%%')",
Expand Down Expand Up @@ -248,7 +250,7 @@ def get_connection_params(self):
if settings_dict['NAME']:
kwargs['db'] = settings_dict['NAME']
if settings_dict['PASSWORD']:
kwargs['passwd'] = force_str(settings_dict['PASSWORD'])
kwargs['passwd'] = settings_dict['PASSWORD']
if settings_dict['HOST'].startswith('/'):
kwargs['unix_socket'] = settings_dict['HOST']
elif settings_dict['HOST']:
Expand All @@ -263,8 +265,8 @@ def get_connection_params(self):

def get_new_connection(self, conn_params):
conn = Database.connect(**conn_params)
conn.encoders[SafeText] = conn.encoders[six.text_type]
conn.encoders[SafeBytes] = conn.encoders[bytes]
conn.encoders[backend_utils.SafeText] = conn.encoders[six.text_type]
conn.encoders[backend_utils.SafeBytes] = conn.encoders[bytes]
return conn

def init_connection_state(self):
Expand Down Expand Up @@ -329,24 +331,29 @@ def check_constraints(self, table_names=None):
if table_names is None:
table_names = self.introspection.table_names(cursor)
for table_name in table_names:
primary_key_column_name = self.introspection.get_primary_key_column(cursor, table_name)
primary_key_column_name = self.introspection\
.get_primary_key_column(
cursor, table_name)
if not primary_key_column_name:
continue
key_columns = self.introspection.get_key_columns(cursor, table_name)
for column_name, referenced_table_name, referenced_column_name in key_columns:
key_columns = self.introspection.get_key_columns(
cursor, table_name)
for column_name, referenced_table_name, \
referenced_column_name in key_columns:
cursor.execute("""
SELECT REFERRING.`%s`, REFERRING.`%s` FROM `%s` as REFERRING
LEFT JOIN `%s` as REFERRED
ON (REFERRING.`%s` = REFERRED.`%s`)
WHERE REFERRING.`%s` IS NOT NULL AND REFERRED.`%s` IS NULL"""
% (primary_key_column_name, column_name, table_name, referenced_table_name,
column_name, referenced_column_name, column_name, referenced_column_name))
% (primary_key_column_name, column_name, table_name, referenced_table_name,
column_name, referenced_column_name, column_name, referenced_column_name))
for bad_row in cursor.fetchall():
raise utils.IntegrityError("The row in table '%s' with primary key '%s' has an invalid "
"foreign key: %s.%s contains a value '%s' that does not have a corresponding value in %s.%s."
% (table_name, bad_row[0],
table_name, column_name, bad_row[1],
referenced_table_name, referenced_column_name))
"foreign key: %s.%s contains a value '%s' that does not have a corresponding value in %s.%s."
% (table_name, bad_row[0],
table_name, column_name, bad_row[
1],
referenced_table_name, referenced_column_name))

def is_usable(self):
try:
Expand All @@ -362,5 +369,6 @@ def mysql_version(self):
server_info = self.connection.get_server_info()
match = server_version_re.match(server_info)
if not match:
raise Exception('Unable to determine MySQL version from version string %r' % server_info)
raise Exception(
'Unable to determine MySQL version from version string %r' % server_info)
return tuple(int(x) for x in match.groups())
116 changes: 116 additions & 0 deletions ibu/backends/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import logging
from time import time

import six


logger = logging.getLogger('ibu.backends')

Expand Down Expand Up @@ -231,3 +233,117 @@ def __get__(self, instance, cls=None):
return self
res = instance.__dict__[self.name] = self.func(instance)
return res


def curry(_curried_func, *args, **kwargs):
def _curried(*moreargs, **morekwargs):
return _curried_func(*(args + moreargs), **dict(kwargs, **morekwargs))
return _curried


class EscapeData(object):
pass


class EscapeBytes(bytes, EscapeData):
"""
A byte string that should be HTML-escaped when output.
"""
pass


class EscapeText(six.text_type, EscapeData):
"""
A unicode string object that should be HTML-escaped when output.
"""
pass

if six.PY3:
EscapeString = EscapeText
else:
EscapeString = EscapeBytes
# backwards compatibility for Python 2
EscapeUnicode = EscapeText


class SafeData(object):
def __html__(self):
"""
Returns the html representation of a string for interoperability.
This allows other template engines to understand Django's SafeData.
"""
return self


class SafeBytes(bytes, SafeData):
"""
A bytes subclass that has been specifically marked as "safe" (requires no
further escaping) for HTML output purposes.
"""

def __add__(self, rhs):
"""
Concatenating a safe byte string with another safe byte string or safe
unicode string is safe. Otherwise, the result is no longer safe.
"""
t = super(SafeBytes, self).__add__(rhs)
if isinstance(rhs, SafeText):
return SafeText(t)
elif isinstance(rhs, SafeBytes):
return SafeBytes(t)
return t

def _proxy_method(self, *args, **kwargs):
"""
Wrap a call to a normal unicode method up so that we return safe
results. The method that is being wrapped is passed in the 'method'
argument.
"""
method = kwargs.pop('method')
data = method(self, *args, **kwargs)
if isinstance(data, bytes):
return SafeBytes(data)
else:
return SafeText(data)

decode = curry(_proxy_method, method=bytes.decode)


class SafeText(six.text_type, SafeData):
"""
A unicode (Python 2) / str (Python 3) subclass that has been specifically
marked as "safe" for HTML output purposes.
"""

def __add__(self, rhs):
"""
Concatenating a safe unicode string with another safe byte string or
safe unicode string is safe. Otherwise, the result is no longer safe.
"""
t = super(SafeText, self).__add__(rhs)
if isinstance(rhs, SafeData):
return SafeText(t)
return t

def _proxy_method(self, *args, **kwargs):
"""
Wrap a call to a normal unicode method up so that we return safe
results. The method that is being wrapped is passed in the 'method'
argument.
"""
method = kwargs.pop('method')
data = method(self, *args, **kwargs)
if isinstance(data, bytes):
return SafeBytes(data)
else:
return SafeText(data)

encode = curry(_proxy_method, method=six.text_type.encode)

if six.PY3:
SafeString = SafeText
else:
SafeString = SafeBytes
# backwards compatibility for Python 2
SafeUnicode = SafeText

0 comments on commit b7be57e

Please sign in to comment.