diff --git a/django/contrib/gis/db/models/query.py b/django/contrib/gis/db/models/query.py index bdb6a670dcdc4..a043afcd9a629 100644 --- a/django/contrib/gis/db/models/query.py +++ b/django/contrib/gis/db/models/query.py @@ -1,6 +1,6 @@ from django.core.exceptions import ImproperlyConfigured from django.db import connection -from django.db.models.query import sql, QuerySet, Q +from django.db.models.query import QuerySet, Q, ValuesQuerySet, ValuesListQuerySet from django.contrib.gis.db.backend import SpatialBackend from django.contrib.gis.db.models import aggregates @@ -9,21 +9,28 @@ from django.contrib.gis.measure import Area, Distance from django.contrib.gis.models import get_srid_info -class GeomSQL(object): - "Simple wrapper object for geometric SQL." - def __init__(self, geo_sql): - self.sql = geo_sql - - def as_sql(self, *args, **kwargs): - return self.sql - class GeoQuerySet(QuerySet): "The Geographic QuerySet." + ### Methods overloaded from QuerySet ### def __init__(self, model=None, query=None): super(GeoQuerySet, self).__init__(model=model, query=query) self.query = query or GeoQuery(self.model, connection) + def values(self, *fields): + return self._clone(klass=GeoValuesQuerySet, setup=True, _fields=fields) + + def values_list(self, *fields, **kwargs): + flat = kwargs.pop('flat', False) + if kwargs: + raise TypeError('Unexpected keyword arguments to values_list: %s' + % (kwargs.keys(),)) + if flat and len(fields) > 1: + raise TypeError("'flat' is not valid when values_list is called with more than one field.") + return self._clone(klass=GeoValuesListQuerySet, setup=True, flat=flat, + _fields=fields) + + ### GeoQuerySet Methods ### def area(self, tolerance=0.05, **kwargs): """ Returns the area of the geographic field in an `area` attribute on @@ -592,3 +599,14 @@ def _geocol_select(self, geo_field, field_name): return self.query._field_column(geo_field, parent_model._meta.db_table) else: return self.query._field_column(geo_field) + +class GeoValuesQuerySet(ValuesQuerySet): + def __init__(self, *args, **kwargs): + super(GeoValuesQuerySet, self).__init__(*args, **kwargs) + # This flag tells `resolve_columns` to run the values through + # `convert_values`. This ensures that Geometry objects instead + # of string values are returned with `values()` or `values_list()`. + self.query.geo_values = True + +class GeoValuesListQuerySet(GeoValuesQuerySet, ValuesListQuerySet): + pass diff --git a/django/contrib/gis/db/models/sql/query.py b/django/contrib/gis/db/models/sql/query.py index e09cb7ce3cbdd..f802a282344f0 100644 --- a/django/contrib/gis/db/models/sql/query.py +++ b/django/contrib/gis/db/models/sql/query.py @@ -14,6 +14,8 @@ ALL_TERMS = sql.constants.QUERY_TERMS.copy() ALL_TERMS.update(SpatialBackend.gis_terms) +TABLE_NAME = sql.constants.TABLE_NAME + class GeoQuery(sql.Query): """ A single spatial SQL query. @@ -64,10 +66,15 @@ def get_columns(self, with_aliases=False): else: col_aliases = set() if self.select: + only_load = self.deferred_to_columns() # This loop customized for GeoQuery. for col, field in izip(self.select, self.select_fields): if isinstance(col, (list, tuple)): - r = self.get_field_select(field, col[0]) + alias, column = col + table = self.alias_map[alias][TABLE_NAME] + if table in only_load and col not in only_load[table]: + continue + r = self.get_field_select(field, alias) if with_aliases: if col[1] in col_aliases: c_alias = 'Col%d' % len(col_aliases) @@ -75,7 +82,7 @@ def get_columns(self, with_aliases=False): aliases.add(c_alias) col_aliases.add(c_alias) else: - result.append('%s AS %s' % (r, col[1])) + result.append('%s AS %s' % (r, qn2(col[1]))) aliases.add(r) col_aliases.add(col[1]) else: @@ -101,7 +108,7 @@ def get_columns(self, with_aliases=False): alias is not None and ' AS %s' % alias or '' ) for alias, aggregate in self.aggregate_select.items() - ]) + ]) # This loop customized for GeoQuery. for (table, col), field in izip(self.related_select_cols, self.related_select_fields): @@ -123,10 +130,14 @@ def get_default_columns(self, with_aliases=False, col_aliases=None, start_alias=None, opts=None, as_pairs=False): """ Computes the default columns for selecting every field in the base - model. + model. Will sometimes be called to pull in related models (e.g. via + select_related), in which case "opts" and "start_alias" will be given + to provide a starting point for the traversal. Returns a list of strings, quoted appropriately for use in SQL - directly, as well as a set of aliases used in the select statement. + directly, as well as a set of aliases used in the select statement (if + 'as_pairs' is True, returns a list of (alias, col_name) pairs instead + of strings as the first component and None as the second component). This routine is overridden from Query to handle customized selection of geometry columns. @@ -134,22 +145,34 @@ def get_default_columns(self, with_aliases=False, col_aliases=None, result = [] if opts is None: opts = self.model._meta - if start_alias: - table_alias = start_alias - else: - table_alias = self.tables[0] - root_pk = opts.pk.column - seen = {None: table_alias} aliases = set() + only_load = self.deferred_to_columns() + proxied_model = opts.proxy and opts.proxy_for_model or 0 + if start_alias: + seen = {None: start_alias} for field, model in opts.get_fields_with_model(): - try: - alias = seen[model] - except KeyError: - alias = self.join((table_alias, model._meta.db_table, - root_pk, model._meta.pk.column)) - seen[model] = alias + if start_alias: + try: + alias = seen[model] + except KeyError: + if model is proxied_model: + alias = start_alias + else: + link_field = opts.get_ancestor_link(model) + alias = self.join((start_alias, model._meta.db_table, + link_field.column, model._meta.pk.column)) + seen[model] = alias + else: + # If we're starting from the base model of the queryset, the + # aliases will have already been set up in pre_sql_setup(), so + # we can save time here. + alias = self.included_inherited_models[model] + table = self.alias_map[alias][TABLE_NAME] + if table in only_load and field.column not in only_load[table]: + continue if as_pairs: result.append((alias, field.column)) + aliases.add(alias) continue # This part of the function is customized for GeoQuery. We # see if there was any custom selection specified in the @@ -166,8 +189,6 @@ def get_default_columns(self, with_aliases=False, col_aliases=None, aliases.add(r) if with_aliases: col_aliases.add(field.column) - if as_pairs: - return result, None return result, aliases def resolve_columns(self, row, fields=()): @@ -191,8 +212,8 @@ def resolve_columns(self, row, fields=()): # distance objects added by GeoQuerySet methods). values = [self.convert_values(v, self.extra_select_fields.get(a, None)) for v, a in izip(row[rn_offset:index_start], aliases)] - if SpatialBackend.oracle: - # This is what happens normally in OracleQuery's `resolve_columns`. + if SpatialBackend.oracle or getattr(self, 'geo_values', False): + # We resolve the columns for value, field in izip(row[index_start:], fields): values.append(self.convert_values(value, field)) else: @@ -215,7 +236,7 @@ def convert_values(self, value, field): value = Distance(**{field.distance_att : value}) elif isinstance(field, AreaField): value = Area(**{field.area_att : value}) - elif isinstance(field, GeomField) and value: + elif isinstance(field, (GeomField, GeometryField)) and value: value = SpatialBackend.Geometry(value) return value diff --git a/django/contrib/gis/tests/relatedapp/models.py b/django/contrib/gis/tests/relatedapp/models.py index 8ea1469b3e4aa..d7dd6bbfd2ced 100644 --- a/django/contrib/gis/tests/relatedapp/models.py +++ b/django/contrib/gis/tests/relatedapp/models.py @@ -2,15 +2,16 @@ from django.contrib.localflavor.us.models import USStateField class Location(models.Model): - name = models.CharField(max_length=50) point = models.PointField() objects = models.GeoManager() + def __unicode__(self): return self.point.wkt class City(models.Model): name = models.CharField(max_length=50) state = USStateField() location = models.ForeignKey(Location) objects = models.GeoManager() + def __unicode__(self): return self.name class AugmentedLocation(Location): extra_text = models.TextField(blank=True) diff --git a/django/contrib/gis/tests/relatedapp/tests.py b/django/contrib/gis/tests/relatedapp/tests.py index 0f2c4b83ee7db..3d162b065e554 100644 --- a/django/contrib/gis/tests/relatedapp/tests.py +++ b/django/contrib/gis/tests/relatedapp/tests.py @@ -118,7 +118,7 @@ def test05_select_related_fk_to_subclass(self): # Regression test for #9752. l = list(DirectoryEntry.objects.all().select_related()) - def test6_f_expressions(self): + def test06_f_expressions(self): "Testing F() expressions on GeometryFields." # Constructing a dummy parcel border and getting the City instance for # assigning the FK. @@ -166,6 +166,31 @@ def test6_f_expressions(self): self.assertEqual(1, len(qs)) self.assertEqual('P1', qs[0].name) + def test07_values(self): + "Testing values() and values_list() and GeoQuerySets." + # GeoQuerySet and GeoValuesQuerySet, and GeoValuesListQuerySet respectively. + gqs = Location.objects.all() + gvqs = Location.objects.values() + gvlqs = Location.objects.values_list() + + # Incrementing through each of the models, dictionaries, and tuples + # returned by the different types of GeoQuerySets. + for m, d, t in zip(gqs, gvqs, gvlqs): + # The values should be Geometry objects and not raw strings returned + # by the spatial database. + self.failUnless(isinstance(d['point'], SpatialBackend.Geometry)) + self.failUnless(isinstance(t[1], SpatialBackend.Geometry)) + self.assertEqual(m.point, d['point']) + self.assertEqual(m.point, t[1]) + + # Test disabled until #10572 is resolved. + #def test08_defer_only(self): + # "Testing defer() and only() on Geographic models." + # qs = Location.objects.all() + # def_qs = Location.objects.defer('point') + # for loc, def_loc in zip(qs, def_qs): + # self.assertEqual(loc.point, def_loc.point) + # TODO: Related tests for KML, GML, and distance lookups. def suite(): diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index 0583386000bb9..ce732b1046d7b 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -784,8 +784,6 @@ def get_default_columns(self, with_aliases=False, col_aliases=None, aliases.add(r) if with_aliases: col_aliases.add(field.column) - if as_pairs: - return result, aliases return result, aliases def get_from_clause(self):