Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed #32096 -- Fixed crash of various expressions with JSONField key transforms. #13530

Merged
merged 6 commits into from
Oct 14, 2020
2 changes: 1 addition & 1 deletion django/contrib/postgres/aggregates/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def as_sql(self, compiler, connection):
ordering_params = []
ordering_expr_sql = []
for expr in self.ordering:
expr_sql, expr_params = expr.as_sql(compiler, connection)
expr_sql, expr_params = compiler.compile(expr)
ordering_expr_sql.append(expr_sql)
ordering_params.extend(expr_params)
sql, sql_params = super().as_sql(compiler, connection, ordering=(
Expand Down
2 changes: 1 addition & 1 deletion django/contrib/postgres/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def _get_expression_sql(self, compiler, schema_editor, query):
if isinstance(expression, str):
expression = F(expression)
expression = expression.resolve_expression(query=query)
sql, params = expression.as_sql(compiler, schema_editor.connection)
sql, params = compiler.compile(expression)
try:
opclass = self.opclasses[idx]
if opclass:
Expand Down
2 changes: 1 addition & 1 deletion django/db/models/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -920,7 +920,7 @@ def get_group_by_cols(self, alias=None):
return expression.get_group_by_cols(alias=alias)

def as_sql(self, compiler, connection):
return self.expression.as_sql(compiler, connection)
return compiler.compile(self.expression)

def __repr__(self):
return "{}({})".format(self.__class__.__name__, self.expression)
Expand Down
37 changes: 18 additions & 19 deletions django/db/models/fields/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,27 +369,26 @@ def as_sqlite(self, compiler, connection):


class KeyTransformIn(lookups.In):
def process_rhs(self, compiler, connection):
rhs, rhs_params = super().process_rhs(compiler, connection)
if not connection.features.has_native_json_field:
func = ()
def resolve_expression_parameter(self, compiler, connection, sql, param):
sql, params = super().resolve_expression_parameter(
compiler, connection, sql, param,
)
if (
not hasattr(param, 'as_sql') and
not connection.features.has_native_json_field
):
if connection.vendor == 'oracle':
func = []
for value in rhs_params:
value = json.loads(value)
function = 'JSON_QUERY' if isinstance(value, (list, dict)) else 'JSON_VALUE'
func.append("%s('%s', '$.value')" % (
function,
json.dumps({'value': value}),
))
func = tuple(func)
rhs_params = ()
elif connection.vendor == 'mysql' and connection.mysql_is_mariadb:
func = ("JSON_UNQUOTE(JSON_EXTRACT(%s, '$'))",) * len(rhs_params)
value = json.loads(param)
if isinstance(value, (list, dict)):
sql = "JSON_QUERY(%s, '$.value')"
else:
sql = "JSON_VALUE(%s, '$.value')"
params = (json.dumps({'value': value}),)
elif connection.vendor in {'sqlite', 'mysql'}:
func = ("JSON_EXTRACT(%s, '$')",) * len(rhs_params)
rhs = rhs % func
return rhs, rhs_params
sql = "JSON_EXTRACT(%s, '$')"
if connection.vendor == 'mysql' and connection.mysql_is_mariadb:
sql = 'JSON_UNQUOTE(%s)' % sql
return sql, params


class KeyTransformExact(JSONExact):
Expand Down
2 changes: 1 addition & 1 deletion django/db/models/lookups.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def resolve_expression_parameter(self, compiler, connection, sql, param):
if hasattr(param, 'resolve_expression'):
param = param.resolve_expression(compiler.query)
if hasattr(param, 'as_sql'):
sql, params = param.as_sql(compiler, connection)
sql, params = compiler.compile(param)
return sql, params

def batch_process_rhs(self, compiler, connection, rhs=None):
Expand Down
20 changes: 20 additions & 0 deletions docs/releases/3.1.3.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,23 @@ Bugfixes
:class:`forms.JSONField <django.forms.JSONField>` and read-only
:class:`models.JSONField <django.db.models.JSONField>` values in the admin
(:ticket:`32080`).

* Fixed a regression in Django 3.1 that caused a crash of
:class:`~django.contrib.postgres.aggregates.ArrayAgg`,
:class:`~django.contrib.postgres.aggregates.JSONBAgg`, and
:class:`~django.contrib.postgres.aggregates.StringAgg` with ``ordering``
on key transforms for :class:`~django.db.models.JSONField` (:ticket:`32096`).

* Fixed a regression in Django 3.1 that caused a crash of ``__in`` lookup when
using key transforms for :class:`~django.db.models.JSONField` in the lookup
value (:ticket:`32096`).

* Fixed a regression in Django 3.1 that caused a crash of
:class:`~django.db.models.ExpressionWrapper` with key transforms for
:class:`~django.db.models.JSONField` (:ticket:`32096`).

* Fixed a regression in Django 3.1 that caused a migrations crash on PostgreSQL
when adding an
:class:`~django.contrib.postgres.constraints.ExclusionConstraint` with key
transforms for :class:`~django.db.models.JSONField` in ``expressions``
(:ticket:`32096`).
7 changes: 7 additions & 0 deletions tests/expressions_window/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,10 @@ class Employee(models.Model):
age = models.IntegerField(blank=False, null=False)
classification = models.ForeignKey('Classification', on_delete=models.CASCADE, null=True)
bonus = models.DecimalField(decimal_places=2, max_digits=15, null=True)


class Detail(models.Model):
value = models.JSONField()

class Meta:
required_db_features = {'supports_json_field'}
39 changes: 35 additions & 4 deletions tests/expressions_window/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,17 @@
from django.core.exceptions import FieldError
from django.db import NotSupportedError, connection
from django.db.models import (
Avg, BooleanField, Case, F, Func, Max, Min, OuterRef, Q, RowRange,
Subquery, Sum, Value, ValueRange, When, Window, WindowFrame,
Avg, BooleanField, Case, F, Func, IntegerField, Max, Min, OuterRef, Q,
RowRange, Subquery, Sum, Value, ValueRange, When, Window, WindowFrame,
)
from django.db.models.fields.json import KeyTextTransform, KeyTransform
from django.db.models.functions import (
CumeDist, DenseRank, ExtractYear, FirstValue, Lag, LastValue, Lead,
Cast, CumeDist, DenseRank, ExtractYear, FirstValue, Lag, LastValue, Lead,
NthValue, Ntile, PercentRank, Rank, RowNumber, Upper,
)
from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature

from .models import Employee
from .models import Detail, Employee


@skipUnlessDBFeature('supports_over_clause')
Expand Down Expand Up @@ -743,6 +744,36 @@ def test_window_expression_within_subquery(self):
{'department': 'Management', 'salary': 100000}
])

@skipUnlessDBFeature('supports_json_field')
def test_key_transform(self):
Detail.objects.bulk_create([
Detail(value={'department': 'IT', 'name': 'Smith', 'salary': 37000}),
Detail(value={'department': 'IT', 'name': 'Nowak', 'salary': 32000}),
Detail(value={'department': 'HR', 'name': 'Brown', 'salary': 50000}),
Detail(value={'department': 'HR', 'name': 'Smith', 'salary': 55000}),
Detail(value={'department': 'PR', 'name': 'Moore', 'salary': 90000}),
])
qs = Detail.objects.annotate(department_sum=Window(
expression=Sum(Cast(
KeyTextTransform('salary', 'value'),
output_field=IntegerField(),
)),
partition_by=[KeyTransform('department', 'value')],
order_by=[KeyTransform('name', 'value')],
)).order_by('value__department', 'department_sum')
self.assertQuerysetEqual(qs, [
('Brown', 'HR', 50000, 50000),
('Smith', 'HR', 55000, 105000),
('Nowak', 'IT', 32000, 32000),
('Smith', 'IT', 37000, 69000),
('Moore', 'PR', 90000, 90000),
], lambda entry: (
entry.value['name'],
entry.value['department'],
entry.value['salary'],
entry.department_sum,
))

def test_invalid_start_value_range(self):
msg = "start argument must be a negative integer, zero, or None, but got '3'."
with self.assertRaisesMessage(ValueError, msg):
Expand Down
26 changes: 25 additions & 1 deletion tests/model_fields/test_jsonfield.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
DataError, IntegrityError, NotSupportedError, OperationalError, connection,
models,
)
from django.db.models import Count, F, OuterRef, Q, Subquery, Transform, Value
from django.db.models import (
Count, ExpressionWrapper, F, IntegerField, OuterRef, Q, Subquery,
Transform, Value,
)
from django.db.models.expressions import RawSQL
from django.db.models.fields.json import (
KeyTextTransform, KeyTransform, KeyTransformFactory,
Expand Down Expand Up @@ -405,6 +408,17 @@ def test_nested_key_transform_expression(self):
[self.objs[4]],
)

def test_expression_wrapper_key_transform(self):
self.assertSequenceEqual(
NullableJSONModel.objects.annotate(
expr=ExpressionWrapper(
KeyTransform('c', 'value'),
output_field=IntegerField(),
),
).filter(expr__isnull=False),
self.objs[3:5],
)

def test_has_key(self):
self.assertSequenceEqual(
NullableJSONModel.objects.filter(value__has_key='a'),
Expand Down Expand Up @@ -700,6 +714,16 @@ def test_key_in(self):
('value__0__in', [1], [self.objs[5]]),
('value__0__in', [1, 3], [self.objs[5]]),
('value__foo__in', ['bar'], [self.objs[7]]),
(
'value__foo__in',
[KeyTransform('foo', KeyTransform('bax', 'value'))],
[self.objs[7]],
),
(
'value__foo__in',
[KeyTransform('foo', KeyTransform('bax', 'value')), 'baz'],
[self.objs[7]],
),
('value__foo__in', ['bar', 'baz'], [self.objs[7]]),
('value__bar__in', [['foo', 'bar']], [self.objs[7]]),
('value__bar__in', [['foo', 'bar'], ['a']], [self.objs[7]]),
Expand Down
12 changes: 10 additions & 2 deletions tests/postgres_tests/migrations/0002_create_test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,11 @@ class Migration(migrations.Migration):
('boolean_field', models.BooleanField(null=True)),
('char_field', models.CharField(max_length=30, blank=True)),
('integer_field', models.IntegerField(null=True)),
]
('json_field', models.JSONField(null=True)),
],
options={
'required_db_vendor': 'postgresql',
},
),
migrations.CreateModel(
name='StatTestModel',
Expand All @@ -215,7 +219,10 @@ class Migration(migrations.Migration):
models.SET_NULL,
null=True,
)),
]
],
options={
'required_db_vendor': 'postgresql',
},
),
migrations.CreateModel(
name='NowTestModel',
Expand Down Expand Up @@ -296,6 +303,7 @@ class Migration(migrations.Migration):
('start', models.DateTimeField()),
('end', models.DateTimeField()),
('cancelled', models.BooleanField(default=False)),
('requirements', models.JSONField(blank=True, null=True)),
],
options={
'required_db_vendor': 'postgresql',
Expand Down
6 changes: 4 additions & 2 deletions tests/postgres_tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,16 +154,17 @@ def __init__(self, *args, **kwargs):
super().__init__(models.IntegerField())


class AggregateTestModel(models.Model):
class AggregateTestModel(PostgreSQLModel):
"""
To test postgres-specific general aggregation functions
"""
char_field = models.CharField(max_length=30, blank=True)
integer_field = models.IntegerField(null=True)
boolean_field = models.BooleanField(null=True)
json_field = models.JSONField(null=True)


class StatTestModel(models.Model):
class StatTestModel(PostgreSQLModel):
"""
To test postgres-specific aggregation functions for statistics
"""
Expand All @@ -190,3 +191,4 @@ class HotelReservation(PostgreSQLModel):
start = models.DateTimeField()
end = models.DateTimeField()
cancelled = models.BooleanField(default=False)
requirements = models.JSONField(blank=True, null=True)