Skip to content

Commit

Permalink
Fixed #9871 -- Geometry objects are now returned in dictionaries and …
Browse files Browse the repository at this point in the history
…tuples returned by `values()` and `values_list()`, respectively; updated `GeoQuery` methods to be compatible with `defer()` and `only`; removed defunct `GeomSQL` class; and removed redundant logic from `Query.get_default_columns`.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@10326 bcc190cf-cafb-0310-a4f2-bffc1f526a37
  • Loading branch information
jbronn committed Apr 1, 2009
1 parent f1c6481 commit 03de1fe
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 35 deletions.
36 changes: 27 additions & 9 deletions django/contrib/gis/db/models/query.py
@@ -1,6 +1,6 @@
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
from django.db import connection from django.db import connection
from django.db.models.query import sql, QuerySet, Q from django.db.models.query import QuerySet, Q, ValuesQuerySet, ValuesListQuerySet


from django.contrib.gis.db.backend import SpatialBackend from django.contrib.gis.db.backend import SpatialBackend
from django.contrib.gis.db.models import aggregates from django.contrib.gis.db.models import aggregates
Expand All @@ -9,21 +9,28 @@
from django.contrib.gis.measure import Area, Distance from django.contrib.gis.measure import Area, Distance
from django.contrib.gis.models import get_srid_info from django.contrib.gis.models import get_srid_info


class GeomSQL(object):
"Simple wrapper object for geometric SQL."
def __init__(self, geo_sql):
self.sql = geo_sql

def as_sql(self, *args, **kwargs):
return self.sql

class GeoQuerySet(QuerySet): class GeoQuerySet(QuerySet):
"The Geographic QuerySet." "The Geographic QuerySet."


### Methods overloaded from QuerySet ###
def __init__(self, model=None, query=None): def __init__(self, model=None, query=None):
super(GeoQuerySet, self).__init__(model=model, query=query) super(GeoQuerySet, self).__init__(model=model, query=query)
self.query = query or GeoQuery(self.model, connection) self.query = query or GeoQuery(self.model, connection)


def values(self, *fields):
return self._clone(klass=GeoValuesQuerySet, setup=True, _fields=fields)

def values_list(self, *fields, **kwargs):
flat = kwargs.pop('flat', False)
if kwargs:
raise TypeError('Unexpected keyword arguments to values_list: %s'
% (kwargs.keys(),))
if flat and len(fields) > 1:
raise TypeError("'flat' is not valid when values_list is called with more than one field.")
return self._clone(klass=GeoValuesListQuerySet, setup=True, flat=flat,
_fields=fields)

### GeoQuerySet Methods ###
def area(self, tolerance=0.05, **kwargs): def area(self, tolerance=0.05, **kwargs):
""" """
Returns the area of the geographic field in an `area` attribute on Returns the area of the geographic field in an `area` attribute on
Expand Down Expand Up @@ -592,3 +599,14 @@ def _geocol_select(self, geo_field, field_name):
return self.query._field_column(geo_field, parent_model._meta.db_table) return self.query._field_column(geo_field, parent_model._meta.db_table)
else: else:
return self.query._field_column(geo_field) return self.query._field_column(geo_field)

class GeoValuesQuerySet(ValuesQuerySet):
def __init__(self, *args, **kwargs):
super(GeoValuesQuerySet, self).__init__(*args, **kwargs)
# This flag tells `resolve_columns` to run the values through
# `convert_values`. This ensures that Geometry objects instead
# of string values are returned with `values()` or `values_list()`.
self.query.geo_values = True

class GeoValuesListQuerySet(GeoValuesQuerySet, ValuesListQuerySet):
pass
65 changes: 43 additions & 22 deletions django/contrib/gis/db/models/sql/query.py
Expand Up @@ -14,6 +14,8 @@
ALL_TERMS = sql.constants.QUERY_TERMS.copy() ALL_TERMS = sql.constants.QUERY_TERMS.copy()
ALL_TERMS.update(SpatialBackend.gis_terms) ALL_TERMS.update(SpatialBackend.gis_terms)


TABLE_NAME = sql.constants.TABLE_NAME

class GeoQuery(sql.Query): class GeoQuery(sql.Query):
""" """
A single spatial SQL query. A single spatial SQL query.
Expand Down Expand Up @@ -64,18 +66,23 @@ def get_columns(self, with_aliases=False):
else: else:
col_aliases = set() col_aliases = set()
if self.select: if self.select:
only_load = self.deferred_to_columns()
# This loop customized for GeoQuery. # This loop customized for GeoQuery.
for col, field in izip(self.select, self.select_fields): for col, field in izip(self.select, self.select_fields):
if isinstance(col, (list, tuple)): if isinstance(col, (list, tuple)):
r = self.get_field_select(field, col[0]) alias, column = col
table = self.alias_map[alias][TABLE_NAME]
if table in only_load and col not in only_load[table]:
continue
r = self.get_field_select(field, alias)
if with_aliases: if with_aliases:
if col[1] in col_aliases: if col[1] in col_aliases:
c_alias = 'Col%d' % len(col_aliases) c_alias = 'Col%d' % len(col_aliases)
result.append('%s AS %s' % (r, c_alias)) result.append('%s AS %s' % (r, c_alias))
aliases.add(c_alias) aliases.add(c_alias)
col_aliases.add(c_alias) col_aliases.add(c_alias)
else: else:
result.append('%s AS %s' % (r, col[1])) result.append('%s AS %s' % (r, qn2(col[1])))
aliases.add(r) aliases.add(r)
col_aliases.add(col[1]) col_aliases.add(col[1])
else: else:
Expand All @@ -101,7 +108,7 @@ def get_columns(self, with_aliases=False):
alias is not None and ' AS %s' % alias or '' alias is not None and ' AS %s' % alias or ''
) )
for alias, aggregate in self.aggregate_select.items() for alias, aggregate in self.aggregate_select.items()
]) ])


# This loop customized for GeoQuery. # This loop customized for GeoQuery.
for (table, col), field in izip(self.related_select_cols, self.related_select_fields): for (table, col), field in izip(self.related_select_cols, self.related_select_fields):
Expand All @@ -123,33 +130,49 @@ def get_default_columns(self, with_aliases=False, col_aliases=None,
start_alias=None, opts=None, as_pairs=False): start_alias=None, opts=None, as_pairs=False):
""" """
Computes the default columns for selecting every field in the base Computes the default columns for selecting every field in the base
model. model. Will sometimes be called to pull in related models (e.g. via
select_related), in which case "opts" and "start_alias" will be given
to provide a starting point for the traversal.
Returns a list of strings, quoted appropriately for use in SQL Returns a list of strings, quoted appropriately for use in SQL
directly, as well as a set of aliases used in the select statement. directly, as well as a set of aliases used in the select statement (if
'as_pairs' is True, returns a list of (alias, col_name) pairs instead
of strings as the first component and None as the second component).
This routine is overridden from Query to handle customized selection of This routine is overridden from Query to handle customized selection of
geometry columns. geometry columns.
""" """
result = [] result = []
if opts is None: if opts is None:
opts = self.model._meta opts = self.model._meta
if start_alias:
table_alias = start_alias
else:
table_alias = self.tables[0]
root_pk = opts.pk.column
seen = {None: table_alias}
aliases = set() aliases = set()
only_load = self.deferred_to_columns()
proxied_model = opts.proxy and opts.proxy_for_model or 0
if start_alias:
seen = {None: start_alias}
for field, model in opts.get_fields_with_model(): for field, model in opts.get_fields_with_model():
try: if start_alias:
alias = seen[model] try:
except KeyError: alias = seen[model]
alias = self.join((table_alias, model._meta.db_table, except KeyError:
root_pk, model._meta.pk.column)) if model is proxied_model:
seen[model] = alias alias = start_alias
else:
link_field = opts.get_ancestor_link(model)
alias = self.join((start_alias, model._meta.db_table,
link_field.column, model._meta.pk.column))
seen[model] = alias
else:
# If we're starting from the base model of the queryset, the
# aliases will have already been set up in pre_sql_setup(), so
# we can save time here.
alias = self.included_inherited_models[model]
table = self.alias_map[alias][TABLE_NAME]
if table in only_load and field.column not in only_load[table]:
continue
if as_pairs: if as_pairs:
result.append((alias, field.column)) result.append((alias, field.column))
aliases.add(alias)
continue continue
# This part of the function is customized for GeoQuery. We # This part of the function is customized for GeoQuery. We
# see if there was any custom selection specified in the # see if there was any custom selection specified in the
Expand All @@ -166,8 +189,6 @@ def get_default_columns(self, with_aliases=False, col_aliases=None,
aliases.add(r) aliases.add(r)
if with_aliases: if with_aliases:
col_aliases.add(field.column) col_aliases.add(field.column)
if as_pairs:
return result, None
return result, aliases return result, aliases


def resolve_columns(self, row, fields=()): def resolve_columns(self, row, fields=()):
Expand All @@ -191,8 +212,8 @@ def resolve_columns(self, row, fields=()):
# distance objects added by GeoQuerySet methods). # distance objects added by GeoQuerySet methods).
values = [self.convert_values(v, self.extra_select_fields.get(a, None)) values = [self.convert_values(v, self.extra_select_fields.get(a, None))
for v, a in izip(row[rn_offset:index_start], aliases)] for v, a in izip(row[rn_offset:index_start], aliases)]
if SpatialBackend.oracle: if SpatialBackend.oracle or getattr(self, 'geo_values', False):
# This is what happens normally in OracleQuery's `resolve_columns`. # We resolve the columns
for value, field in izip(row[index_start:], fields): for value, field in izip(row[index_start:], fields):
values.append(self.convert_values(value, field)) values.append(self.convert_values(value, field))
else: else:
Expand All @@ -215,7 +236,7 @@ def convert_values(self, value, field):
value = Distance(**{field.distance_att : value}) value = Distance(**{field.distance_att : value})
elif isinstance(field, AreaField): elif isinstance(field, AreaField):
value = Area(**{field.area_att : value}) value = Area(**{field.area_att : value})
elif isinstance(field, GeomField) and value: elif isinstance(field, (GeomField, GeometryField)) and value:
value = SpatialBackend.Geometry(value) value = SpatialBackend.Geometry(value)
return value return value


Expand Down
3 changes: 2 additions & 1 deletion django/contrib/gis/tests/relatedapp/models.py
Expand Up @@ -2,15 +2,16 @@
from django.contrib.localflavor.us.models import USStateField from django.contrib.localflavor.us.models import USStateField


class Location(models.Model): class Location(models.Model):
name = models.CharField(max_length=50)
point = models.PointField() point = models.PointField()
objects = models.GeoManager() objects = models.GeoManager()
def __unicode__(self): return self.point.wkt


class City(models.Model): class City(models.Model):
name = models.CharField(max_length=50) name = models.CharField(max_length=50)
state = USStateField() state = USStateField()
location = models.ForeignKey(Location) location = models.ForeignKey(Location)
objects = models.GeoManager() objects = models.GeoManager()
def __unicode__(self): return self.name


class AugmentedLocation(Location): class AugmentedLocation(Location):
extra_text = models.TextField(blank=True) extra_text = models.TextField(blank=True)
Expand Down
27 changes: 26 additions & 1 deletion django/contrib/gis/tests/relatedapp/tests.py
Expand Up @@ -118,7 +118,7 @@ def test05_select_related_fk_to_subclass(self):
# Regression test for #9752. # Regression test for #9752.
l = list(DirectoryEntry.objects.all().select_related()) l = list(DirectoryEntry.objects.all().select_related())


def test6_f_expressions(self): def test06_f_expressions(self):
"Testing F() expressions on GeometryFields." "Testing F() expressions on GeometryFields."
# Constructing a dummy parcel border and getting the City instance for # Constructing a dummy parcel border and getting the City instance for
# assigning the FK. # assigning the FK.
Expand Down Expand Up @@ -166,6 +166,31 @@ def test6_f_expressions(self):
self.assertEqual(1, len(qs)) self.assertEqual(1, len(qs))
self.assertEqual('P1', qs[0].name) self.assertEqual('P1', qs[0].name)


def test07_values(self):
"Testing values() and values_list() and GeoQuerySets."
# GeoQuerySet and GeoValuesQuerySet, and GeoValuesListQuerySet respectively.
gqs = Location.objects.all()
gvqs = Location.objects.values()
gvlqs = Location.objects.values_list()

# Incrementing through each of the models, dictionaries, and tuples
# returned by the different types of GeoQuerySets.
for m, d, t in zip(gqs, gvqs, gvlqs):
# The values should be Geometry objects and not raw strings returned
# by the spatial database.
self.failUnless(isinstance(d['point'], SpatialBackend.Geometry))
self.failUnless(isinstance(t[1], SpatialBackend.Geometry))
self.assertEqual(m.point, d['point'])
self.assertEqual(m.point, t[1])

# Test disabled until #10572 is resolved.
#def test08_defer_only(self):
# "Testing defer() and only() on Geographic models."
# qs = Location.objects.all()
# def_qs = Location.objects.defer('point')
# for loc, def_loc in zip(qs, def_qs):
# self.assertEqual(loc.point, def_loc.point)

# TODO: Related tests for KML, GML, and distance lookups. # TODO: Related tests for KML, GML, and distance lookups.


def suite(): def suite():
Expand Down
2 changes: 0 additions & 2 deletions django/db/models/sql/query.py
Expand Up @@ -784,8 +784,6 @@ def get_default_columns(self, with_aliases=False, col_aliases=None,
aliases.add(r) aliases.add(r)
if with_aliases: if with_aliases:
col_aliases.add(field.column) col_aliases.add(field.column)
if as_pairs:
return result, aliases
return result, aliases return result, aliases


def get_from_clause(self): def get_from_clause(self):
Expand Down

0 comments on commit 03de1fe

Please sign in to comment.