Skip to content

Commit

Permalink
Refactor expressions into smaller components
Browse files Browse the repository at this point in the history
Inspiration for splitting the expression classes comes from PeeWee
which has a fairly good object model.

Split ExpressionNode into smaller units

All tests passing after refactor

Clean up expression implementation

Summarised aggregate state is passed through prepare

Make Col types into expressions

Small tweaks to improve the Func type

Fix regression with non-aggregate annotations

Tests for custom sql functions

Change arguments name to expressions

Fix documentation related to expression refactor

fix regressions when renaming template params
  • Loading branch information
jarshwah committed May 16, 2014
1 parent 047808c commit 362fec1
Show file tree
Hide file tree
Showing 20 changed files with 607 additions and 472 deletions.
2 changes: 1 addition & 1 deletion django/contrib/contenttypes/fields.py
Expand Up @@ -10,7 +10,7 @@
from django.db.models.base import ModelBase
from django.db.models.fields.related import ForeignObject, ForeignObjectRel
from django.db.models.related import PathInfo
from django.db.models.sql.datastructures import Col
from django.db.models.expressions import Col
from django.contrib.contenttypes.models import ContentType
from django.utils.encoding import smart_text, python_2_unicode_compatible

Expand Down
4 changes: 2 additions & 2 deletions django/contrib/gis/db/backends/oracle/operations.py
Expand Up @@ -282,9 +282,9 @@ def spatial_aggregate_sql(self, agg):
if agg_name == 'union':
agg_name += 'agg'
if agg.is_extent:
sql_template = '%(function)s(%(field)s)'
sql_template = '%(function)s(%(expressions)s)'
else:
sql_template = '%(function)s(SDOAGGRTYPE(%(field)s,%(tolerance)s))'
sql_template = '%(function)s(SDOAGGRTYPE(%(expressions)s,%(tolerance)s))'
sql_function = getattr(self, agg_name)
return self.select % sql_template, sql_function

Expand Down
2 changes: 1 addition & 1 deletion django/contrib/gis/db/backends/postgis/operations.py
Expand Up @@ -566,7 +566,7 @@ def spatial_aggregate_sql(self, agg):
agg_name = agg_name.lower()
if agg_name == 'union':
agg_name += 'agg'
sql_template = '%(function)s(%(field)s)'
sql_template = '%(function)s(%(expressions)s)'
sql_function = getattr(self, agg_name)
return sql_template, sql_function

Expand Down
2 changes: 1 addition & 1 deletion django/contrib/gis/db/backends/spatialite/operations.py
Expand Up @@ -315,7 +315,7 @@ def spatial_aggregate_sql(self, agg):
agg_name = agg_name.lower()
if agg_name == 'union':
agg_name += 'agg'
sql_template = self.select % '%(function)s(%(field)s)'
sql_template = self.select % '%(function)s(%(expressions)s)'
sql_function = getattr(self, agg_name)
return sql_template, sql_function

Expand Down
20 changes: 9 additions & 11 deletions django/contrib/gis/db/models/aggregates.py
Expand Up @@ -5,8 +5,8 @@


class GeoAggregate(Aggregate):
sql_template = None
sql_function = None
template = None
function = None
is_extent = False
conversion_class = None # TODO: is this still used?

Expand All @@ -16,17 +16,15 @@ def as_sql(self, compiler, connection):
self.tolerance = 0.05
self.extra['tolerance'] = self.tolerance

sql_template, sql_function = connection.ops.spatial_aggregate_sql(self)
if sql_template is None:
sql_template = '%(function)s(%(field)s)'
if self.extra.get('sql_template', None) is None:
self.extra['sql_template'] = sql_template
if not 'sql_function' in self.extra:
self.extra['sql_function'] = sql_function

template, function = connection.ops.spatial_aggregate_sql(self)
if template is None:
template = '%(function)s(%(expressions)s)'
self.extra['template'] = self.extra.get('template', template)
self.extra['function'] = self.extra.get('function', function)
return super(GeoAggregate, self).as_sql(compiler, connection)

def prepare(self, query=None, allow_joins=True, reuse=None):
def prepare(self, query=None, allow_joins=True, reuse=None, summarise=False):
self.is_summary = summarise
super(GeoAggregate, self).prepare(query, allow_joins, reuse)
if not isinstance(self.expression.output_type, GeometryField):
raise ValueError('Geospatial aggregates only allowed on geometry fields.')
Expand Down
4 changes: 2 additions & 2 deletions django/contrib/gis/db/models/sql/aggregates.py
Expand Up @@ -14,7 +14,7 @@

class GeoAggregate(Aggregate):
# Default SQL template for spatial aggregates.
sql_template = '%(function)s(%(field)s)'
sql_template = '%(function)s(%(expressions)s)'

# Conversion class, if necessary.
conversion_class = None
Expand Down Expand Up @@ -51,7 +51,7 @@ def as_sql(self, qn, connection):

substitutions = {
'function': sql_function,
'field': field_name
'expressions': field_name
}
substitutions.update(self.extra)

Expand Down
2 changes: 1 addition & 1 deletion django/db/models/__init__.py
Expand Up @@ -4,7 +4,7 @@

from django.core.exceptions import ObjectDoesNotExist, ImproperlyConfigured # NOQA
from django.db.models.query import Q, QuerySet, Prefetch # NOQA
from django.db.models.expressions import ExpressionNode, F, Value, WrappedExpression # NOQA
from django.db.models.expressions import ExpressionNode, F, Value, Func # NOQA
from django.db.models.manager import Manager # NOQA
from django.db.models.base import Model # NOQA
from django.db.models.aggregates import * # NOQA
Expand Down
50 changes: 27 additions & 23 deletions django/db/models/aggregates.py
Expand Up @@ -2,7 +2,7 @@
Classes to represent the definitions of aggregate functions.
"""
from django.core.exceptions import FieldError
from django.db.models.expressions import WrappedExpression, Value
from django.db.models.expressions import Func, Value
from django.db.models.fields import IntegerField, FloatField

__all__ = [
Expand All @@ -13,12 +13,17 @@
float_field = FloatField()


class Aggregate(WrappedExpression):
class Aggregate(Func):
is_aggregate = True
name = None

def __init__(self, expression, output_type=None, **extra):
super(Aggregate, self).__init__(expression, output_type, **extra)
super(Aggregate, self).__init__(
expression,
output_type=output_type,
**extra)

self.expression = self.expressions[0]
if self.expression.is_aggregate:
raise FieldError("Cannot compute %s(%s(..)): aggregates cannot be nested" % (
self.name, expression.name))
Expand All @@ -29,8 +34,9 @@ def __init__(self, expression, output_type=None, **extra):
elif self.is_computed:
self.source = float_field

def prepare(self, query=None, allow_joins=True, reuse=None):
if self.expression.validate_name: # simple lookup
def prepare(self, query=None, allow_joins=True, reuse=None, summarise=False):
self.is_summary = summarise
if hasattr(self.expression, 'name'): # simple lookup
name = self.expression.name
reffed, _ = self.expression.contains_aggregate(query.annotations)
if reffed and not self.is_summary:
Expand All @@ -45,22 +51,22 @@ def prepare(self, query=None, allow_joins=True, reuse=None):
self.expression.col = (None, name)
return
self._patch_aggregate(query) # backward-compatibility support
super(Aggregate, self).prepare(query, allow_joins, reuse)
super(Aggregate, self).prepare(query, allow_joins, reuse, summarise)

def refs_field(self, aggregate_types, field_types):
return (isinstance(self, aggregate_types) and
isinstance(self.expression.source, field_types))

def _default_alias(self):
if hasattr(self.expression, 'name') and self.expression.validate_name:
@property
def default_alias(self):
if hasattr(self.expression, 'name'):
return '%s__%s' % (self.expression.name, self.name.lower())
raise TypeError("Complex expressions require an alias")
default_alias = property(_default_alias)

def _patch_aggregate(self, query):
"""
Helper method for patching 3rd party aggregates that do not yet support
the new way of subclassing.
the new way of subclassing. This method should be removed in 2.0
add_to_query(query, alias, col, source, is_summary) will be defined on
legacy aggregates which, in turn, instantiates the SQL implementation of
Expand All @@ -76,32 +82,30 @@ def add_to_query(self, query, alias, col, source, is_summary):
By supplying a known alias, we can get the SQLAggregate out of the aggregates
dict, and use the sql_function and sql_template attributes to patch *this* aggregate.
"""
if not hasattr(self, 'add_to_query') or self.sql_function is not None:
if not hasattr(self, 'add_to_query') or self.function is not None:
return

# raise a deprecation warning?

placeholder_alias = "_XXXXXXXX_"
self.add_to_query(query, placeholder_alias, None, None, None)
sql_aggregate = query.aggregates.pop(placeholder_alias)
if 'sql_function' not in self.extra and hasattr(sql_aggregate, 'sql_function'):
self.extra['sql_function'] = sql_aggregate.sql_function
self.extra['function'] = sql_aggregate.sql_function

if hasattr(sql_aggregate, 'sql_template'):
self.extra['sql_template'] = sql_aggregate.sql_template
self.extra['template'] = sql_aggregate.sql_template


class Avg(Aggregate):
is_computed = True
sql_function = 'AVG'
function = 'AVG'
name = 'Avg'


class Count(Aggregate):
is_ordinal = True
sql_function = 'COUNT'
function = 'COUNT'
name = 'Count'
sql_template = '%(function)s(%(distinct)s%(field)s)'
template = '%(function)s(%(distinct)s%(expressions)s)'

def __init__(self, expression, distinct=False, **extra):
if expression == '*':
Expand All @@ -111,12 +115,12 @@ def __init__(self, expression, distinct=False, **extra):


class Max(Aggregate):
sql_function = 'MAX'
function = 'MAX'
name = 'Max'


class Min(Aggregate):
sql_function = 'MIN'
function = 'MIN'
name = 'Min'


Expand All @@ -125,12 +129,12 @@ class StdDev(Aggregate):
name = 'StdDev'

def __init__(self, expression, sample=False, **extra):
self.function = 'STDDEV_SAMP' if sample else 'STDDEV_POP'
super(StdDev, self).__init__(expression, **extra)
self.sql_function = 'STDDEV_SAMP' if sample else 'STDDEV_POP'


class Sum(Aggregate):
sql_function = 'SUM'
function = 'SUM'
name = 'Sum'


Expand All @@ -139,5 +143,5 @@ class Variance(Aggregate):
name = 'Variance'

def __init__(self, expression, sample=False, **extra):
self.function = 'VAR_SAMP' if sample else 'VAR_POP'
super(Variance, self).__init__(expression, **extra)
self.sql_function = 'VAR_SAMP' if sample else 'VAR_POP'

0 comments on commit 362fec1

Please sign in to comment.