diff --git a/django/contrib/gis/db/models/lookups.py b/django/contrib/gis/db/models/lookups.py index 3f3bc8acc8aa8..2f382acf913ec 100644 --- a/django/contrib/gis/db/models/lookups.py +++ b/django/contrib/gis/db/models/lookups.py @@ -1,7 +1,7 @@ import re from django.contrib.gis.db.models.fields import BaseSpatialField -from django.db.models.expressions import Col, Expression +from django.db.models.expressions import Expression from django.db.models.lookups import Lookup, Transform from django.db.models.sql.query import Query @@ -55,30 +55,19 @@ def process_band_indices(self, only_lhs=False): def get_db_prep_lookup(self, value, connection): # get_db_prep_lookup is called by process_rhs from super class - return ('%s', [connection.ops.Adapter(value)] + (self.rhs_params or [])) + return ('%s', [connection.ops.Adapter(value)]) def process_rhs(self, compiler, connection): if isinstance(self.rhs, Query): # If rhs is some Query, don't touch it. return super().process_rhs(compiler, connection) - geom = self.rhs - if isinstance(self.rhs, Col): - # Make sure the F Expression destination field exists, and - # set an `srid` attribute with the same as that of the - # destination. - geo_fld = self.rhs.output_field - if not hasattr(geo_fld, 'srid'): - raise ValueError('No geographic field found in expression.') - self.rhs.srid = geo_fld.srid - sql, _ = compiler.compile(geom) - return connection.ops.get_geom_placeholder(self.lhs.output_field, geom, compiler) % sql, [] - elif isinstance(self.rhs, Expression): - raise ValueError('Complex expressions not supported for spatial fields.') + if isinstance(self.rhs, Expression): + self.rhs = self.rhs.resolve_expression(compiler.query) rhs, rhs_params = super().process_rhs(compiler, connection) - rhs = connection.ops.get_geom_placeholder(self.lhs.output_field, geom, compiler) - return rhs, rhs_params + placeholder = connection.ops.get_geom_placeholder(self.lhs.output_field, self.rhs, compiler) + return placeholder % rhs, rhs_params def get_rhs_op(self, connection, rhs): # Unlike BuiltinLookup, the GIS get_rhs_op() implementation should return @@ -267,18 +256,17 @@ class RelateLookup(GISLookup): sql_template = '%(func)s(%(lhs)s, %(rhs)s, %%s)' pattern_regex = re.compile(r'^[012TF\*]{9}$') - def get_db_prep_lookup(self, value, connection): - if len(self.rhs_params) != 1: - raise ValueError('relate must be passed a two-tuple') + def process_rhs(self, compiler, connection): # Check the pattern argument + pattern = self.rhs_params[0] backend_op = connection.ops.gis_operators[self.lookup_name] if hasattr(backend_op, 'check_relate_argument'): - backend_op.check_relate_argument(self.rhs_params[0]) - else: - pattern = self.rhs_params[0] - if not isinstance(pattern, str) or not self.pattern_regex.match(pattern): - raise ValueError('Invalid intersection matrix pattern "%s".' % pattern) - return super().get_db_prep_lookup(value, connection) + backend_op.check_relate_argument(pattern) + elif not isinstance(pattern, str) or not self.pattern_regex.match(pattern): + raise ValueError('Invalid intersection matrix pattern "%s".' % pattern) + + sql, params = super().process_rhs(compiler, connection) + return sql, params + [pattern] @BaseSpatialField.register_lookup @@ -322,8 +310,8 @@ class DWithinLookup(DistanceLookupBase): def process_rhs(self, compiler, connection): dist_sql, dist_params = self.process_distance(compiler, connection) self.template_params['value'] = dist_sql - rhs = connection.ops.get_geom_placeholder(self.lhs.output_field, self.rhs, compiler) - return rhs, [connection.ops.Adapter(self.rhs)] + dist_params + rhs_sql, params = super().process_rhs(compiler, connection) + return rhs_sql, params + dist_params class DistanceLookupFromFunction(DistanceLookupBase): diff --git a/tests/gis_tests/distapp/tests.py b/tests/gis_tests/distapp/tests.py index ec6c40eac3890..395f7226ef429 100644 --- a/tests/gis_tests/distapp/tests.py +++ b/tests/gis_tests/distapp/tests.py @@ -73,6 +73,9 @@ def test_dwithin(self): with self.subTest(dist=dist, qs=qs): self.assertEqual(tx_cities, self.get_names(qs)) + # With a complex geometry expression + self.assertFalse(SouthTexasCity.objects.exclude(point__dwithin=(Union('point', 'point'), 0))) + # Now performing the `dwithin` queries on a geodetic coordinate system. for dist in au_dists: with self.subTest(dist=dist): diff --git a/tests/gis_tests/geoapp/tests.py b/tests/gis_tests/geoapp/tests.py index 1e6d111ba5d74..6bff300837b3a 100644 --- a/tests/gis_tests/geoapp/tests.py +++ b/tests/gis_tests/geoapp/tests.py @@ -448,6 +448,18 @@ def test_relate_lookup(self): self.assertEqual('Texas', Country.objects.get(mpoly__relate=(pnt2, intersects_mask)).name) self.assertEqual('Lawrence', City.objects.get(point__relate=(ks.poly, intersects_mask)).name) + # With a complex geometry expression + mask = 'anyinteract' if oracle else within_mask + self.assertFalse(City.objects.exclude(point__relate=(functions.Union('point', 'point'), mask))) + + def test_gis_lookups_with_complex_expressions(self): + multiple_arg_lookups = {'dwithin', 'relate'} # These lookups are tested in other places. + lookups = connection.ops.gis_operators.keys() - multiple_arg_lookups + self.assertTrue(lookups, 'No lookups found') + for lookup in lookups: + with self.subTest(lookup): + City.objects.filter(**{'point__' + lookup: functions.Union('point', 'point')}).exists() + class GeoQuerySetTest(TestCase): # TODO: GeoQuerySet is removed, organize these test better.