Skip to content

Replace conditionals with as_postgresql and as_mysql in as_sql methods #1

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 31, 2019
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 70 additions & 91 deletions jsonfallback/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@
from django_mysql.utils import connection_is_mariadb


LOOKUPS_NOT_SUPPORTED_MESSAGE = (
'Lookups on JSONFields are only supported on PostgreSQL and MySQL at the moment.'
)


class JsonAdapter(jsonb.JsonAdapter):
"""
Customized psycopg2.extras.Json to allow for a custom encoder.
Expand Down Expand Up @@ -126,44 +131,33 @@ def get_lookup(self, lookup_name):


class FallbackLookup:
def as_sql(self, qn, connection):
if '.postgresql' in connection.settings_dict['ENGINE']:
return super().as_sql(qn, connection)
raise NotSupportedError(
'Lookups on JSONFields are only supported on PostgreSQL and MySQL at the moment.'
)

def as_sqlite(self, compiler, connection):
raise NotSupportedError(LOOKUPS_NOT_SUPPORTED_MESSAGE)


@FallbackJSONField.register_lookup
class DataContains(FallbackLookup, lookups.DataContains):

def as_sql(self, qn, connection):
if '.postgresql' in connection.settings_dict['ENGINE']:
return super().as_sql(qn, connection)
if '.mysql' in connection.settings_dict['ENGINE']:
lhs, lhs_params = self.process_lhs(qn, connection)
rhs, rhs_params = self.process_rhs(qn, connection)
for i, p in enumerate(rhs_params):
rhs_params[i] = p.dumps(p.adapted) # Convert JSONAdapter to str
params = lhs_params + rhs_params
return 'JSON_CONTAINS({}, {})'.format(lhs, rhs), params
raise NotSupportedError('Lookup not supported for %s' % connection.settings_dict['ENGINE'])
def as_mysql(self, compiler, connection):
lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.process_rhs(compiler, connection)
for i, p in enumerate(rhs_params):
rhs_params[i] = p.dumps(p.adapted) # Convert JSONAdapter to str
params = lhs_params + rhs_params
return 'JSON_CONTAINS({}, {})'.format(lhs, rhs), params


@FallbackJSONField.register_lookup
class ContainedBy(FallbackLookup, lookups.ContainedBy):

def as_sql(self, qn, connection):
if '.postgresql' in connection.settings_dict['ENGINE']:
return super().as_sql(qn, connection)
if '.mysql' in connection.settings_dict['ENGINE']:
lhs, lhs_params = self.process_lhs(qn, connection)
rhs, rhs_params = self.process_rhs(qn, connection)
for i, p in enumerate(rhs_params):
rhs_params[i] = p.dumps(p.adapted) # Convert JSONAdapter to str
params = rhs_params + lhs_params
return 'JSON_CONTAINS({}, {})'.format(rhs, lhs), params
raise NotSupportedError('Lookup not supported for %s' % connection.settings_dict['ENGINE'])
def as_mysql(self, compiler, connection):
lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.process_rhs(compiler, connection)
for i, p in enumerate(rhs_params):
rhs_params[i] = p.dumps(p.adapted) # Convert JSONAdapter to str
params = rhs_params + lhs_params
return 'JSON_CONTAINS({}, {})'.format(rhs, lhs), params


@FallbackJSONField.register_lookup
Expand All @@ -176,16 +170,12 @@ def get_prep_lookup(self):
)
return super().get_prep_lookup()

def as_sql(self, qn, connection):
if '.postgresql' in connection.settings_dict['ENGINE']:
return super().as_sql(qn, connection)
if '.mysql' in connection.settings_dict['ENGINE']:
lhs, lhs_params = self.process_lhs(qn, connection)
key_name = self.rhs
path = '$.{}'.format(json.dumps(key_name))
params = lhs_params + [path]
return "JSON_CONTAINS_PATH({}, 'one', %s)".format(lhs), params
raise NotSupportedError('Lookup not supported for %s' % connection.settings_dict['ENGINE'])
def as_mysql(self, compiler, connection):
lhs, lhs_params = self.process_lhs(compiler, connection)
key_name = self.rhs
path = '$.{}'.format(json.dumps(key_name))
params = lhs_params + [path]
return "JSON_CONTAINS_PATH({}, 'one', %s)".format(lhs), params


class JSONSequencesMixin(object):
Expand All @@ -200,43 +190,35 @@ def get_prep_lookup(self):
@FallbackJSONField.register_lookup
class HasKeys(FallbackLookup, lookups.HasKeys):

def as_sql(self, qn, connection):
if '.postgresql' in connection.settings_dict['ENGINE']:
return super().as_sql(qn, connection)
if '.mysql' in connection.settings_dict['ENGINE']:
lhs, lhs_params = self.process_lhs(qn, connection)
paths = [
'$.{}'.format(json.dumps(key_name))
for key_name in self.rhs
]
params = lhs_params + paths
def as_mysql(self, compiler, connection):
lhs, lhs_params = self.process_lhs(compiler, connection)
paths = [
'$.{}'.format(json.dumps(key_name))
for key_name in self.rhs
]
params = lhs_params + paths

sql = ['JSON_CONTAINS_PATH(', lhs, ", 'all', "]
sql.append(', '.join('%s' for _ in paths))
sql.append(')')
return ''.join(sql), params
raise NotSupportedError('Lookup not supported for %s' % connection.settings_dict['ENGINE'])
sql = ['JSON_CONTAINS_PATH(', lhs, ", 'all', "]
sql.append(', '.join('%s' for _ in paths))
sql.append(')')
return ''.join(sql), params


@FallbackJSONField.register_lookup
class HasAnyKeys(FallbackLookup, lookups.HasAnyKeys):

def as_sql(self, qn, connection):
if '.postgresql' in connection.settings_dict['ENGINE']:
return super().as_sql(qn, connection)
if '.mysql' in connection.settings_dict['ENGINE']:
lhs, lhs_params = self.process_lhs(qn, connection)
paths = [
'$.{}'.format(json.dumps(key_name))
for key_name in self.rhs
]
params = lhs_params + paths
def as_mysql(self, compiler, connection):
lhs, lhs_params = self.process_lhs(compiler, connection)
paths = [
'$.{}'.format(json.dumps(key_name))
for key_name in self.rhs
]
params = lhs_params + paths

sql = ['JSON_CONTAINS_PATH(', lhs, ", 'one', "]
sql.append(', '.join('%s' for _ in paths))
sql.append(')')
return ''.join(sql), params
raise NotSupportedError('Lookup not supported for %s' % connection.settings_dict['ENGINE'])
sql = ['JSON_CONTAINS_PATH(', lhs, ", 'one', "]
sql.append(', '.join('%s' for _ in paths))
sql.append(')')
return ''.join(sql), params


class JSONValue(Func):
Expand Down Expand Up @@ -286,23 +268,20 @@ def process_rhs(self, compiler, connection):


class FallbackKeyTransform(jsonb.KeyTransform):
def as_sql(self, compiler, connection):
if '.postgresql' in connection.settings_dict['ENGINE']:
return super().as_sql(compiler, connection)
elif '.mysql' in connection.settings_dict['ENGINE']:
key_transforms = [self.key_name]
previous = self.lhs
while isinstance(previous, FallbackKeyTransform):
key_transforms.insert(0, previous.key_name)
previous = previous.lhs

lhs, params = compiler.compile(previous)
json_path = mysql_compile_json_path(key_transforms)
return 'JSON_EXTRACT({}, %s)'.format(lhs), params + [json_path]

raise NotSupportedError(
'Transforms on JSONFields are only supported on PostgreSQL and MySQL at the moment.'
)

def as_mysql(self, compiler, connection):
key_transforms = [self.key_name]
previous = self.lhs
while isinstance(previous, FallbackKeyTransform):
key_transforms.insert(0, previous.key_name)
previous = previous.lhs

lhs, params = compiler.compile(previous)
json_path = mysql_compile_json_path(key_transforms)
return 'JSON_EXTRACT({}, %s)'.format(lhs), params + [json_path]

def as_sqlite(self, compiler, connection):
raise NotSupportedError(LOOKUPS_NOT_SUPPORTED_MESSAGE)


class FallbackKeyTransformFactory:
Expand Down Expand Up @@ -336,8 +315,8 @@ def __init__(self, key_transform, *args, **kwargs):


class StringKeyTransformTextLookupMixin(KeyTransformTextLookupMixin):
def process_rhs(self, qn, connection):
rhs = super().process_rhs(qn, connection)
def process_rhs(self, compiler, connection):
rhs = super().process_rhs(compiler, connection)
if '.mysql' in connection.settings_dict['ENGINE']:
params = []
for p in rhs[1]:
Expand All @@ -347,8 +326,8 @@ def process_rhs(self, qn, connection):


class NonStringKeyTransformTextLookupMixin:
def process_rhs(self, qn, connection):
rhs = super().process_rhs(qn, connection)
def process_rhs(self, compiler, connection):
rhs = super().process_rhs(compiler, connection)
if '.mysql' in connection.settings_dict['ENGINE']:
params = []
for p in rhs[1]:
Expand All @@ -367,8 +346,8 @@ def process_lhs(self, compiler, connection, lhs=None):
lhs = 'LOWER(%s)' % lhs[0], lhs[1]
return lhs

def process_rhs(self, qn, connection):
rhs = super().process_rhs(qn, connection)
def process_rhs(self, compiler, connection):
rhs = super().process_rhs(compiler, connection)
if '.mysql' in connection.settings_dict['ENGINE']:
rhs = 'LOWER(%s)' % rhs[0], rhs[1]
return rhs
Expand Down