Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

queryset-refactor: Ported DateQuerySet and ValueQuerySet over and fix…

…ed most of

the related tests.


git-svn-id: http://code.djangoproject.com/svn/django/branches/queryset-refactor@6486 bcc190cf-cafb-0310-a4f2-bffc1f526a37
  • Loading branch information...
commit 988b3bbdcb52c4e20551fc8936369a7f176f48bc 1 parent bcdedbb
Malcolm Tredinnick malcolmt authored
12 django/db/models/base.py
View
@@ -338,13 +338,15 @@ def _get_FIELD_display(self, field):
def _get_next_or_previous_by_FIELD(self, field, is_next, **kwargs):
qn = connection.ops.quote_name
op = is_next and '>' or '<'
- where = '(%s %s %%s OR (%s = %%s AND %s.%s %s %%s))' % \
+ where = ['(%s %s %%s OR (%s = %%s AND %s.%s %s %%s))' % \
(qn(field.column), op, qn(field.column),
- qn(self._meta.db_table), qn(self._meta.pk.column), op)
+ qn(self._meta.db_table), qn(self._meta.pk.column), op)]
param = smart_str(getattr(self, field.attname))
- q = self.__class__._default_manager.filter(**kwargs).order_by((not is_next and '-' or '') + field.name, (not is_next and '-' or '') + self._meta.pk.name)
- q.extra(where=where, params=[param, param,
- getattr(self, self._meta.pk.attname)])
+ order_char = not is_next and '-' or ''
+ q = self.__class__._default_manager.filter(**kwargs).order_by(
+ order_char + field.name, order_char + self._meta.pk.name)
+ q = q.extra(where=where, params=[param, param,
+ getattr(self, self._meta.pk.attname)])
try:
return q[0]
except IndexError:
108 django/db/models/query.py
View
@@ -253,7 +253,6 @@ def delete(self):
def values(self, *fields):
return self._clone(klass=ValuesQuerySet, _fields=fields)
- # FIXME: Not converted yet!
def dates(self, field_name, kind, order='ASC'):
"""
Returns a list of datetime objects representing all available dates
@@ -265,8 +264,10 @@ def dates(self, field_name, kind, order='ASC'):
"'order' must be either 'ASC' or 'DESC'."
# Let the FieldDoesNotExist exception propagate.
field = self.model._meta.get_field(field_name, many_to_many=False)
- assert isinstance(field, DateField), "%r isn't a DateField." % field_name
- return self._clone(klass=DateQuerySet, _field=field, _kind=kind, _order=order)
+ assert isinstance(field, DateField), "%r isn't a DateField." \
+ % field_name
+ return self._clone(klass=DateQuerySet, _field=field, _kind=kind,
+ _order=order)
##################################################################
# PUBLIC METHODS THAT ALTER ATTRIBUTES AND RETURN A NEW QUERYSET #
@@ -389,16 +390,8 @@ def __init__(self, *args, **kwargs):
self.query.select_related = False
def iterator(self):
- try:
- select, sql, params = self._get_sql_clause()
- except EmptyResultSet:
- raise StopIteration
-
- qn = connection.ops.quote_name
-
- # self._select is a dictionary, and dictionaries' key order is
- # undefined, so we convert it to a list of tuples.
- extra_select = self._select.items()
+ extra_select = self.query.extra_select.keys()
+ extra_select.sort()
# Construct two objects -- fields and field_names.
# fields is a list of Field objects to fetch.
@@ -406,39 +399,30 @@ def iterator(self):
# resulting dictionaries.
if self._fields:
if not extra_select:
- fields = [self.model._meta.get_field(f, many_to_many=False) for f in self._fields]
+ fields = [self.model._meta.get_field(f, many_to_many=False)
+ for f in self._fields]
field_names = self._fields
else:
fields = []
field_names = []
for f in self._fields:
if f in [field.name for field in self.model._meta.fields]:
- fields.append(self.model._meta.get_field(f, many_to_many=False))
+ fields.append(self.model._meta.get_field(f,
+ many_to_many=False))
field_names.append(f)
- elif not self._select.has_key(f):
- raise FieldDoesNotExist('%s has no field named %r' % (self.model._meta.object_name, f))
+ elif not self.query.extra_select.has_key(f):
+ raise FieldDoesNotExist('%s has no field named %r'
+ % (self.model._meta.object_name, f))
else: # Default to all fields.
fields = self.model._meta.fields
field_names = [f.attname for f in fields]
- columns = [f.column for f in fields]
- select = ['%s.%s' % (qn(self.model._meta.db_table), qn(c)) for c in columns]
+ self.query.add_local_columns([f.column for f in fields])
if extra_select:
- select.extend(['(%s) AS %s' % (quote_only_if_word(s[1]), qn(s[0])) for s in extra_select])
- field_names.extend([f[0] for f in extra_select])
-
- cursor = connection.cursor()
- cursor.execute("SELECT " + (self._distinct and "DISTINCT " or "") + ",".join(select) + sql, params)
-
- has_resolve_columns = hasattr(self, 'resolve_columns')
- while 1:
- rows = cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)
- if not rows:
- raise StopIteration
- for row in rows:
- if has_resolve_columns:
- row = self.resolve_columns(row, fields)
- yield dict(zip(field_names, row))
+ field_names.extend([f for f in extra_select])
+
+ for row in self.query.results_iter():
+ yield dict(zip(field_names, row))
def _clone(self, klass=None, **kwargs):
c = super(ValuesQuerySet, self)._clone(klass, **kwargs)
@@ -447,60 +431,19 @@ def _clone(self, klass=None, **kwargs):
class DateQuerySet(QuerySet):
def iterator(self):
- from django.db.backends.util import typecast_timestamp
- from django.db.models.fields import DateTimeField
-
- qn = connection.ops.quote_name
- self._order_by = () # Clear this because it'll mess things up otherwise.
+ self.query = self.query.clone(klass=sql.DateQuery)
+ self.query.select = []
+ self.query.add_date_select(self._field.column, self._kind, self._order)
if self._field.null:
- self._where.append('%s.%s IS NOT NULL' % \
- (qn(self.model._meta.db_table), qn(self._field.column)))
- try:
- select, sql, params = self._get_sql_clause()
- except EmptyResultSet:
- raise StopIteration
-
- table_name = qn(self.model._meta.db_table)
- field_name = qn(self._field.column)
-
- if connection.features.allows_group_by_ordinal:
- group_by = '1'
- else:
- group_by = connection.ops.date_trunc_sql(self._kind, '%s.%s' % (table_name, field_name))
-
- sql = 'SELECT %s %s GROUP BY %s ORDER BY 1 %s' % \
- (connection.ops.date_trunc_sql(self._kind, '%s.%s' % (qn(self.model._meta.db_table),
- qn(self._field.column))), sql, group_by, self._order)
- cursor = connection.cursor()
- cursor.execute(sql, params)
-
- has_resolve_columns = hasattr(self, 'resolve_columns')
- needs_datetime_string_cast = connection.features.needs_datetime_string_cast
- dates = []
- # It would be better to use self._field here instead of DateTimeField(),
- # but in Oracle that will result in a list of datetime.date instead of
- # datetime.datetime.
- fields = [DateTimeField()]
- while 1:
- rows = cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)
- if not rows:
- return dates
- for row in rows:
- date = row[0]
- if has_resolve_columns:
- date = self.resolve_columns([date], fields)[0]
- elif needs_datetime_string_cast:
- date = typecast_timestamp(str(date))
- dates.append(date)
+ self.query.add_filter(('%s__isnull' % self._field.name, True))
+ return self.query.results_iter()
def _clone(self, klass=None, **kwargs):
c = super(DateQuerySet, self)._clone(klass, **kwargs)
c._field = self._field
c._kind = self._kind
- c._order = self._order
return c
-# XXX; Everything below here is done.
class EmptyQuerySet(QuerySet):
def __init__(self, model=None):
super(EmptyQuerySet, self).__init__(model)
@@ -517,6 +460,11 @@ def _clone(self, klass=None, **kwargs):
c._result_cache = []
return c
+ def iterator(self):
+ # This slightly odd construction is because we need an empty generator
+ # (it should raise StopIteration immediately).
+ yield iter([]).next()
+
# QOperator, QAnd and QOr are temporarily retained for backwards compatibility.
# All the old functionality is now part of the 'Q' class.
class QOperator(Q):
23 django/db/models/sql/datastructures.py
View
@@ -57,3 +57,26 @@ def as_sql(self, quote_func=None):
else:
return 'COUNT(%s)' % col
+class Date(object):
+ """
+ Add a date selection column.
+ """
+ def __init__(self, col, lookup_type, date_sql_func):
+ self.col = col
+ self.lookup_type = lookup_type
+ self.date_sql_func= date_sql_func
+
+ def relabel_aliases(self, change_map):
+ c = self.col
+ if isinstance(c, (list, tuple)):
+ self.col = (change_map.get(c[0], c[0]), c[1])
+
+ def as_sql(self, quote_func=None):
+ if not quote_func:
+ quote_func = lambda x: x
+ if isinstance(self.col, (list, tuple)):
+ col = '%s.%s' % tuple([quote_func(c) for c in self.col])
+ else:
+ col = self.col
+ return self.date_sql_func(self.lookup_type, col)
+
128 django/db/models/sql/query.py
View
@@ -11,8 +11,8 @@
from django.utils import tree
from django.db.models.sql.where import WhereNode, AND, OR
-from django.db.models.sql.datastructures import Count
-from django.db.models.fields import FieldDoesNotExist
+from django.db.models.sql.datastructures import Count, Date
+from django.db.models.fields import FieldDoesNotExist, Field
from django.contrib.contenttypes import generic
from datastructures import EmptyResultSet
from utils import handle_legacy_orderlist
@@ -54,6 +54,7 @@
SINGLE = 'single'
NONE = None
+# FIXME: Add quote_name() calls around all the tables.
class Query(object):
"""
A single SQL query.
@@ -77,8 +78,8 @@ def __init__(self, model, connection):
self.select = []
self.tables = [] # Aliases in the order they are created.
self.where = WhereNode(self)
- self.having = []
self.group_by = []
+ self.having = []
self.order_by = []
self.low_mark, self.high_mark = 0, None # Used for offset/limit
self.distinct = False
@@ -103,12 +104,14 @@ def __str__(self):
sql, params = self.as_sql()
return sql % params
- def clone(self, **kwargs):
+ def clone(self, klass=None, **kwargs):
"""
Creates a copy of the current instance. The 'kwargs' parameter can be
used by clients to update attributes after copying has taken place.
"""
- obj = self.__class__(self.model, self.connection)
+ if not klass:
+ klass = self.__class__
+ obj = klass(self.model, self.connection)
obj.table_map = self.table_map.copy()
obj.alias_map = copy.deepcopy(self.alias_map)
obj.join_map = copy.deepcopy(self.join_map)
@@ -198,7 +201,16 @@ def as_sql(self, with_limits=True):
where, params = self.where.as_sql()
if where:
result.append('WHERE %s' % where)
- result.append(' AND'.join(self.extra_where))
+ if self.extra_where:
+ if not where:
+ result.append('WHERE')
+ else:
+ result.append('AND')
+ result.append(' AND'.join(self.extra_where))
+
+ if self.group_by:
+ grouping = self.get_grouping()
+ result.append('GROUP BY %s' % ', '.join(grouping))
ordering = self.get_ordering()
if ordering:
@@ -312,12 +324,12 @@ def get_columns(self):
"""
qn = self.connection.ops.quote_name
result = []
- if self.select:
+ if self.select or self.extra_select:
for col in self.select:
if isinstance(col, (list, tuple)):
result.append('%s.%s' % (qn(col[0]), qn(col[1])))
else:
- result.append(col.as_sql())
+ result.append(col.as_sql(quote_func=qn))
else:
table_alias = self.tables[0]
result = ['%s.%s' % (table_alias, qn(f.column))
@@ -331,6 +343,21 @@ def get_columns(self):
for alias, col in extra_select])
return result
+ def get_grouping(self):
+ """
+ Returns a tuple representing the SQL elements in the "group by" clause.
+ """
+ qn = self.connection.ops.quote_name
+ result = []
+ for col in self.group_by:
+ if isinstance(col, (list, tuple)):
+ result.append('%s.%s' % (qn(col[0]), qn(col[1])))
+ elif hasattr(col, 'as_sql'):
+ result.append(col.as_sql(qn))
+ else:
+ result.append(str(col))
+ return result
+
def get_ordering(self):
"""
Returns a tuple representing the SQL elements in the "order by" clause.
@@ -339,10 +366,18 @@ def get_ordering(self):
qn = self.connection.ops.quote_name
opts = self.model._meta
result = []
- for field in handle_legacy_orderlist(ordering):
+ for field in ordering:
if field == '?':
result.append(self.connection.ops.random_function_sql())
continue
+ if isinstance(field, int):
+ if field < 0:
+ order = 'DESC'
+ field = -field
+ else:
+ order = 'ASC'
+ result.append('%s %s' % (field, order))
+ continue
if field[0] == '-':
col = field[1:]
order = 'DESC'
@@ -683,10 +718,28 @@ def clear_limits(self):
"""
self.low_mark, self.high_mark = 0, None
+ def can_filter(self):
+ """
+ Returns True if adding filters to this instance is still possible.
+
+ Typically, this means no limits or offsets have been put on the results.
+ """
+ return not (self.low_mark or self.high_mark)
+
+ def add_local_columns(self, columns):
+ """
+ Adds the given column names to the select set, assuming they come from
+ the root model (the one given in self.model).
+ """
+ table = self.model._meta.db_table
+ self.select.extend([(table, col) for col in columns])
+
def add_ordering(self, *ordering):
"""
Adds items from the 'ordering' sequence to the query's "order by"
- clause.
+ clause. These items are either field names (not column names) --
+ possibly with a direction prefix ('-' or '?') -- or ordinals,
+ corresponding to column positions in the 'select' list.
"""
self.order_by.extend(ordering)
@@ -696,14 +749,6 @@ def clear_ordering(self):
"""
self.order_by = []
- def can_filter(self):
- """
- Returns True if adding filters to this instance is still possible.
-
- Typically, this means no limits or offsets have been put on the results.
- """
- return not (self.low_mark or self.high_mark)
-
def add_count_column(self):
"""
Converts the query to do count(*) or count(distinct(pk)) in order to
@@ -713,12 +758,12 @@ def add_count_column(self):
# that it doesn't totally overwrite the select list.
if not self.distinct:
select = Count()
- # Distinct handling is now done in Count(), so don't do it at this
- # level.
- self.distinct = False
else:
select = Count((self.table_map[self.model._meta.db_table][0],
self.model._meta.pk.column), True)
+ # Distinct handling is done in Count(), so don't do it at this
+ # level.
+ self.distinct = False
self.select = [select]
self.extra_select = {}
@@ -873,6 +918,47 @@ def clear_related(self, related_field, pk_list):
values = [(related_field.column, 'NULL')]
self.do_query(self.model._meta.db_table, values, where)
+class DateQuery(Query):
+ """
+ A DateQuery is a normal query, except that it specifically selects a single
+ date field. This requires some special handling when converting the results
+ back to Python objects, so we put it in a separate class.
+ """
+ def results_iter(self):
+ """
+ Returns an iterator over the results from executing this query.
+ """
+ resolve_columns = hasattr(self, 'resolve_columns')
+ if resolve_columns:
+ from django.db.models.fields import DateTimeField
+ fields = [DateTimeField()]
+ else:
+ from django.db.backends.util import typecast_timestamp
+ needs_string_cast = self.connection.features.needs_datetime_string_cast
+
+ for rows in self.execute_sql(MULTI):
+ for row in rows:
+ date = row[0]
+ if resolve_columns:
+ date = self.resolve_columns([date], fields)[0]
+ elif needs_string_cast:
+ date = typecast_timestamp(str(date))
+ yield date
+
+ def add_date_select(self, column, lookup_type, order='ASC'):
+ """
+ Converts the query into a date extraction query.
+ """
+ alias = self.join((None, self.model._meta.db_table, None, None))
+ select = Date((alias, column), lookup_type,
+ self.connection.ops.date_trunc_sql)
+ self.select = [select]
+ self.order_by = order == 'ASC' and [1] or [-1]
+ if self.connection.features.allows_group_by_ordinal:
+ self.group_by = [1]
+ else:
+ self.group_by = [select]
+
def find_field(name, field_list, related_query):
"""
Finds a field with a specific name in a list of field instances.
Please sign in to comment.
Something went wrong with that request. Please try again.