Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace conditionals with as_postgresql and as_mysql in as_sql methods #1

Merged
merged 4 commits into from May 31, 2019
Merged
Changes from all commits
Commits
File filter...
Filter file types
Jump to…
Jump to file or symbol
Failed to load files and symbols.

Always

Just for now

@@ -11,6 +11,11 @@
from django_mysql.utils import connection_is_mariadb


LOOKUPS_NOT_SUPPORTED_MESSAGE = (
'Lookups on JSONFields are only supported on PostgreSQL and MySQL at the moment.'
)


class JsonAdapter(jsonb.JsonAdapter):
"""
Customized psycopg2.extras.Json to allow for a custom encoder.
@@ -126,44 +131,33 @@ def get_lookup(self, lookup_name):


class FallbackLookup:
def as_sql(self, qn, connection):
if '.postgresql' in connection.settings_dict['ENGINE']:
return super().as_sql(qn, connection)
raise NotSupportedError(
'Lookups on JSONFields are only supported on PostgreSQL and MySQL at the moment.'
)

def as_sqlite(self, compiler, connection):
raise NotSupportedError(LOOKUPS_NOT_SUPPORTED_MESSAGE)


@FallbackJSONField.register_lookup
class DataContains(FallbackLookup, lookups.DataContains):

def as_sql(self, qn, connection):
if '.postgresql' in connection.settings_dict['ENGINE']:
return super().as_sql(qn, connection)
if '.mysql' in connection.settings_dict['ENGINE']:
lhs, lhs_params = self.process_lhs(qn, connection)
rhs, rhs_params = self.process_rhs(qn, connection)
for i, p in enumerate(rhs_params):
rhs_params[i] = p.dumps(p.adapted) # Convert JSONAdapter to str
params = lhs_params + rhs_params
return 'JSON_CONTAINS({}, {})'.format(lhs, rhs), params
raise NotSupportedError('Lookup not supported for %s' % connection.settings_dict['ENGINE'])
def as_mysql(self, compiler, connection):
lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.process_rhs(compiler, connection)
for i, p in enumerate(rhs_params):
rhs_params[i] = p.dumps(p.adapted) # Convert JSONAdapter to str
params = lhs_params + rhs_params
return 'JSON_CONTAINS({}, {})'.format(lhs, rhs), params


@FallbackJSONField.register_lookup
class ContainedBy(FallbackLookup, lookups.ContainedBy):

def as_sql(self, qn, connection):
if '.postgresql' in connection.settings_dict['ENGINE']:
return super().as_sql(qn, connection)
if '.mysql' in connection.settings_dict['ENGINE']:
lhs, lhs_params = self.process_lhs(qn, connection)
rhs, rhs_params = self.process_rhs(qn, connection)
for i, p in enumerate(rhs_params):
rhs_params[i] = p.dumps(p.adapted) # Convert JSONAdapter to str
params = rhs_params + lhs_params
return 'JSON_CONTAINS({}, {})'.format(rhs, lhs), params
raise NotSupportedError('Lookup not supported for %s' % connection.settings_dict['ENGINE'])
def as_mysql(self, compiler, connection):
lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.process_rhs(compiler, connection)
for i, p in enumerate(rhs_params):
rhs_params[i] = p.dumps(p.adapted) # Convert JSONAdapter to str
params = rhs_params + lhs_params
return 'JSON_CONTAINS({}, {})'.format(rhs, lhs), params


@FallbackJSONField.register_lookup
@@ -176,16 +170,12 @@ def get_prep_lookup(self):
)
return super().get_prep_lookup()

def as_sql(self, qn, connection):
if '.postgresql' in connection.settings_dict['ENGINE']:
return super().as_sql(qn, connection)
if '.mysql' in connection.settings_dict['ENGINE']:
lhs, lhs_params = self.process_lhs(qn, connection)
key_name = self.rhs
path = '$.{}'.format(json.dumps(key_name))
params = lhs_params + [path]
return "JSON_CONTAINS_PATH({}, 'one', %s)".format(lhs), params
raise NotSupportedError('Lookup not supported for %s' % connection.settings_dict['ENGINE'])
def as_mysql(self, compiler, connection):
lhs, lhs_params = self.process_lhs(compiler, connection)
key_name = self.rhs
path = '$.{}'.format(json.dumps(key_name))
params = lhs_params + [path]
return "JSON_CONTAINS_PATH({}, 'one', %s)".format(lhs), params


class JSONSequencesMixin(object):
@@ -200,43 +190,35 @@ def get_prep_lookup(self):
@FallbackJSONField.register_lookup
class HasKeys(FallbackLookup, lookups.HasKeys):

def as_sql(self, qn, connection):
if '.postgresql' in connection.settings_dict['ENGINE']:
return super().as_sql(qn, connection)
if '.mysql' in connection.settings_dict['ENGINE']:
lhs, lhs_params = self.process_lhs(qn, connection)
paths = [
'$.{}'.format(json.dumps(key_name))
for key_name in self.rhs
]
params = lhs_params + paths
def as_mysql(self, compiler, connection):
lhs, lhs_params = self.process_lhs(compiler, connection)
paths = [
'$.{}'.format(json.dumps(key_name))
for key_name in self.rhs
]
params = lhs_params + paths

sql = ['JSON_CONTAINS_PATH(', lhs, ", 'all', "]
sql.append(', '.join('%s' for _ in paths))
sql.append(')')
return ''.join(sql), params
raise NotSupportedError('Lookup not supported for %s' % connection.settings_dict['ENGINE'])
sql = ['JSON_CONTAINS_PATH(', lhs, ", 'all', "]
sql.append(', '.join('%s' for _ in paths))
sql.append(')')
return ''.join(sql), params


@FallbackJSONField.register_lookup
class HasAnyKeys(FallbackLookup, lookups.HasAnyKeys):

def as_sql(self, qn, connection):
if '.postgresql' in connection.settings_dict['ENGINE']:
return super().as_sql(qn, connection)
if '.mysql' in connection.settings_dict['ENGINE']:
lhs, lhs_params = self.process_lhs(qn, connection)
paths = [
'$.{}'.format(json.dumps(key_name))
for key_name in self.rhs
]
params = lhs_params + paths
def as_mysql(self, compiler, connection):
lhs, lhs_params = self.process_lhs(compiler, connection)
paths = [
'$.{}'.format(json.dumps(key_name))
for key_name in self.rhs
]
params = lhs_params + paths

sql = ['JSON_CONTAINS_PATH(', lhs, ", 'one', "]
sql.append(', '.join('%s' for _ in paths))
sql.append(')')
return ''.join(sql), params
raise NotSupportedError('Lookup not supported for %s' % connection.settings_dict['ENGINE'])
sql = ['JSON_CONTAINS_PATH(', lhs, ", 'one', "]
sql.append(', '.join('%s' for _ in paths))
sql.append(')')
return ''.join(sql), params


class JSONValue(Func):
@@ -286,23 +268,20 @@ def process_rhs(self, compiler, connection):


class FallbackKeyTransform(jsonb.KeyTransform):
def as_sql(self, compiler, connection):
if '.postgresql' in connection.settings_dict['ENGINE']:
return super().as_sql(compiler, connection)
elif '.mysql' in connection.settings_dict['ENGINE']:
key_transforms = [self.key_name]
previous = self.lhs
while isinstance(previous, FallbackKeyTransform):
key_transforms.insert(0, previous.key_name)
previous = previous.lhs

lhs, params = compiler.compile(previous)
json_path = mysql_compile_json_path(key_transforms)
return 'JSON_EXTRACT({}, %s)'.format(lhs), params + [json_path]

raise NotSupportedError(
'Transforms on JSONFields are only supported on PostgreSQL and MySQL at the moment.'
)

def as_mysql(self, compiler, connection):
key_transforms = [self.key_name]
previous = self.lhs
while isinstance(previous, FallbackKeyTransform):
key_transforms.insert(0, previous.key_name)
previous = previous.lhs

lhs, params = compiler.compile(previous)
json_path = mysql_compile_json_path(key_transforms)
return 'JSON_EXTRACT({}, %s)'.format(lhs), params + [json_path]

def as_sqlite(self, compiler, connection):
raise NotSupportedError(LOOKUPS_NOT_SUPPORTED_MESSAGE)


class FallbackKeyTransformFactory:
@@ -336,8 +315,8 @@ def __init__(self, key_transform, *args, **kwargs):


class StringKeyTransformTextLookupMixin(KeyTransformTextLookupMixin):
def process_rhs(self, qn, connection):
rhs = super().process_rhs(qn, connection)
def process_rhs(self, compiler, connection):
rhs = super().process_rhs(compiler, connection)
if '.mysql' in connection.settings_dict['ENGINE']:
params = []
for p in rhs[1]:
@@ -347,8 +326,8 @@ def process_rhs(self, qn, connection):


class NonStringKeyTransformTextLookupMixin:
def process_rhs(self, qn, connection):
rhs = super().process_rhs(qn, connection)
def process_rhs(self, compiler, connection):
rhs = super().process_rhs(compiler, connection)
if '.mysql' in connection.settings_dict['ENGINE']:
params = []
for p in rhs[1]:
@@ -367,8 +346,8 @@ def process_lhs(self, compiler, connection, lhs=None):
lhs = 'LOWER(%s)' % lhs[0], lhs[1]
return lhs

def process_rhs(self, qn, connection):
rhs = super().process_rhs(qn, connection)
def process_rhs(self, compiler, connection):
rhs = super().process_rhs(compiler, connection)
if '.mysql' in connection.settings_dict['ENGINE']:
rhs = 'LOWER(%s)' % rhs[0], rhs[1]
return rhs
ProTip! Use n and p to navigate between commits in a pull request.
You can’t perform that action at this time.