Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed #30715 -- Fixed crash of ArrayField lookups on ArrayAgg annotations over AutoField. #11699

Merged
merged 2 commits into from
Aug 23, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
39 changes: 19 additions & 20 deletions django/contrib/postgres/fields/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@ def db_type(self, connection):
size = self.size or ''
return '%s[%s]' % (self.base_field.db_type(connection), size)

def cast_db_type(self, connection):
size = self.size or ''
return '%s[%s]' % (self.base_field.cast_db_type(connection), size)

def get_placeholder(self, value, compiler, connection):
return '%s::{}'.format(self.db_type(connection))

Expand Down Expand Up @@ -190,36 +194,31 @@ def formfield(self, **kwargs):
})


class ArrayCastRHSMixin:
def process_rhs(self, compiler, connection):
rhs, rhs_params = super().process_rhs(compiler, connection)
cast_type = self.lhs.output_field.cast_db_type(connection)
return '%s::%s' % (rhs, cast_type), rhs_params


@ArrayField.register_lookup
class ArrayContains(lookups.DataContains):
def as_sql(self, qn, connection):
sql, params = super().as_sql(qn, connection)
sql = '%s::%s' % (sql, self.lhs.output_field.db_type(connection))
return sql, params
class ArrayContains(ArrayCastRHSMixin, lookups.DataContains):
pass


@ArrayField.register_lookup
class ArrayContainedBy(lookups.ContainedBy):
def as_sql(self, qn, connection):
sql, params = super().as_sql(qn, connection)
sql = '%s::%s' % (sql, self.lhs.output_field.db_type(connection))
return sql, params
class ArrayContainedBy(ArrayCastRHSMixin, lookups.ContainedBy):
pass


@ArrayField.register_lookup
class ArrayExact(Exact):
def as_sql(self, qn, connection):
sql, params = super().as_sql(qn, connection)
sql = '%s::%s' % (sql, self.lhs.output_field.db_type(connection))
return sql, params
class ArrayExact(ArrayCastRHSMixin, Exact):
pass


@ArrayField.register_lookup
class ArrayOverlap(lookups.Overlap):
def as_sql(self, qn, connection):
sql, params = super().as_sql(qn, connection)
sql = '%s::%s' % (sql, self.lhs.output_field.db_type(connection))
return sql, params
class ArrayOverlap(ArrayCastRHSMixin, lookups.Overlap):
pass


@ArrayField.register_lookup
Expand Down
22 changes: 22 additions & 0 deletions tests/postgres_tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
)

try:
from django.contrib.postgres.aggregates import ArrayAgg
from django.contrib.postgres.fields import ArrayField
from django.contrib.postgres.fields.array import IndexTransform, SliceTransform
from django.contrib.postgres.forms import (
Expand Down Expand Up @@ -280,6 +281,27 @@ def test_overlap_charfield(self):
[]
)

def test_lookups_autofield_array(self):
qs = NullableIntegerArrayModel.objects.filter(
field__0__isnull=False,
).values('field__0').annotate(
arrayagg=ArrayAgg('id'),
).order_by('field__0')
tests = (
('contained_by', [self.objs[1].pk, self.objs[2].pk, 0], [2]),
('contains', [self.objs[2].pk], [2]),
('exact', [self.objs[3].pk], [20]),
('overlap', [self.objs[1].pk, self.objs[3].pk], [2, 20]),
)
for lookup, value, expected in tests:
with self.subTest(lookup=lookup):
self.assertSequenceEqual(
qs.filter(
**{'arrayagg__' + lookup: value},
).values_list('field__0', flat=True),
expected,
)

def test_index(self):
self.assertSequenceEqual(
NullableIntegerArrayModel.objects.filter(field__0=2),
Expand Down