Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Initial commit (import from django-orm-extensions).
- Loading branch information
Andrey Antukh
committed
Sep 15, 2012
0 parents
commit e5c9019
Showing
12 changed files
with
552 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
*.pyc | ||
.*swp | ||
doc/build | ||
dist | ||
versiontools* | ||
build* | ||
*.egg* |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
===================== | ||
djorm-ext-expressions | ||
===================== | ||
|
||
Django by default, provides a wide range of field types and generic lookups for queries. This in many cases is more than enough. But there are cases where you need to use types defined for yourself and search operators that are not defined in django lookups and another important case is to make searches requiring the execution of some function in WHERE clause. | ||
|
||
In django, for these last two cases, it requires writing SQL statements. ``djorm-ext-expressions`` introduces the method ``manager.where()`` and some class'es (SqlExpression, SqlFunction, AND, OR, ...) to facilite sql construction for advanced cases. | ||
|
||
Simple usage | ||
------------ | ||
|
||
Imagine some django model with postgresql integer array field. You need to find objects in the field containing a set of group numbers. | ||
|
||
**NOTE**: array field is part of django orm extensions package and is located on ``djorm-ext-pgarray`` submodule. | ||
|
||
**Example model definition** | ||
|
||
.. code-block:: python | ||
from django.db import models | ||
from djorm_expressions.models import ExpressionManager | ||
from .somefiels import ArrayField | ||
class Register(models.Model): | ||
name = models.CharField(max_length=200) | ||
points = ArrayField(dbtype="int") | ||
objects = ExpressionManager() | ||
With this model definition, we can do this searches:: | ||
|
||
from djorm_expressions.base import SqlExpression, AND, OR | ||
|
||
# search all register items that points field contains [2,3] | ||
|
||
qs = Register.manager.where( | ||
SqlExpression("points", "@>", [2,3]) | ||
) | ||
|
||
# search all register items that points fields contains [2,3] or [5,6] | ||
|
||
expression = OR( | ||
SqlExpression("points", "@>", [2,3]), | ||
SqlExpression("points", "@>", [5,6]), | ||
) | ||
|
||
qs = Register.objects.where(expression) | ||
|
||
|
||
Also, we can use functions to construct a expression:: | ||
|
||
from djorm_expressions.base import SqlFunction | ||
|
||
class BitLength(SqlFunction): | ||
sql_function = "bit_length" | ||
|
||
# search all registers items that bit_length(name) > 20. | ||
qs = Register.objects.where( | ||
SqlExpression(BitLength("name"), ">", 20) | ||
) | ||
|
||
|
||
I finally can redefine the behavior "SqlExpression" and make it more "object oriented":: | ||
|
||
class ArrayExpression(object): | ||
def __init__(self, field): | ||
self.field = field | ||
|
||
def contains(self, value): | ||
return SqlExpression(self.field, "@>", value) | ||
|
||
def overlap(self, value): | ||
return SqlExpression(self.field, "&&", value) | ||
|
||
# search all register items that points field contains [2,3] | ||
qs = Register.objects.where( | ||
ArrayExpression("points").contains([2,3]) | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
__version__ = (4, 0, 0, 'final', 0) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
from .utils import _setup_joins_for_fields | ||
from .tree import AND, OR | ||
|
||
class SqlNode(object): | ||
negated = False | ||
|
||
sql_negated_template = "NOT %s" | ||
|
||
@property | ||
def field_parts(self): | ||
raise NotImplementedError | ||
|
||
def as_sql(self, qn, queryset): | ||
raise NotImplementedError | ||
|
||
def __invert__(self): | ||
# TODO: use clone insetead self modification. | ||
self.negated = True | ||
return self | ||
|
||
|
||
class SqlExpression(SqlNode): | ||
sql_template = "%(field)s %(operator)s %%s" | ||
|
||
def __init__(self, field_or_func, operator, value=None, **kwargs): | ||
self.operator = operator | ||
self.value = value | ||
self.extra = kwargs | ||
|
||
if isinstance(field_or_func, SqlNode): | ||
self.field = field_or_func.field | ||
self.sql_function = field_or_func | ||
else: | ||
self.field = field_or_func | ||
self.sql_function = None | ||
|
||
@property | ||
def field_parts(self): | ||
return self.field.split("__") | ||
|
||
def as_sql(self, qn, queryset): | ||
""" | ||
Return the statement rendered as sql. | ||
""" | ||
|
||
# setup joins if needed | ||
if self.sql_function is None: | ||
_setup_joins_for_fields(self.field_parts, self, queryset) | ||
|
||
# build sql | ||
params, args = {}, [] | ||
|
||
if self.operator is not None: | ||
params['operator'] = self.operator | ||
|
||
if self.sql_function is None: | ||
if isinstance(self.field, basestring): | ||
params['field'] = qn(self.field) | ||
elif isinstance(self.field, (tuple, list)): | ||
_tbl, _fld = self.field | ||
params['field'] = "%s.%s" % (qn(_tbl), qn(_fld)) | ||
else: | ||
raise ValueError("Invalid field value") | ||
else: | ||
params['field'], _args = self.sql_function.as_sql(qn, queryset) | ||
args.extend(_args) | ||
|
||
params.update(self.extra) | ||
if self.value is not None: | ||
args.extend([self.value]) | ||
|
||
template_result = self.sql_template % params | ||
|
||
if self.negated: | ||
return self.sql_negated_template % (template_result), args | ||
|
||
return template_result, args | ||
|
||
|
||
class RawExpression(SqlExpression): | ||
field_parts = [] | ||
|
||
def __init__(self, sqlstatement, *args): | ||
self.statement = sqlstatement | ||
self.params = args | ||
|
||
def as_sql(self, qn, queryset): | ||
return self.statement, self.params | ||
|
||
|
||
# TODO: add function(function()) feature. | ||
|
||
class SqlFunction(SqlNode): | ||
sql_template = '%(function)s(%(field)s)' | ||
sql_function = None | ||
args = [] | ||
|
||
def __init__(self, field, *args, **kwargs): | ||
self.field = field | ||
self.args = args | ||
self.extern_params = kwargs | ||
|
||
@property | ||
def field_parts(self): | ||
return self.field.split("__") | ||
|
||
def as_sql(self, qn, queryset): | ||
""" | ||
Return the aggregate/annotation rendered as sql. | ||
""" | ||
|
||
_setup_joins_for_fields(self.field_parts, self, queryset) | ||
|
||
params = {} | ||
if self.sql_function is not None: | ||
params['function'] = self.sql_function | ||
if isinstance(self.field, basestring): | ||
params['field'] = qn(self.field) | ||
elif isinstance(self.field, (tuple, list)): | ||
_tbl, _fld = self.field | ||
params['field'] = "%s.%s" % (qn(_tbl), qn(_fld)) | ||
else: | ||
raise ValueError("Invalid field value") | ||
|
||
params.update(self.extern_params) | ||
return self.sql_template % params, self.args |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
from django.utils.datastructures import SortedDict | ||
from django.db.models.sql.where import ExtraWhere | ||
from django.db.models.query import QuerySet | ||
from django.db import models | ||
|
||
from .base import AND | ||
|
||
|
||
class ExpressionQuerySetMixin(object): | ||
def annotate_functions(self, **kwargs): | ||
extra_select, params = SortedDict(), [] | ||
clone = self._clone() | ||
|
||
for alias, node in kwargs.iteritems(): | ||
_sql, _params = node.as_sql(self.quote_name, self) | ||
|
||
extra_select[alias] = _sql | ||
params.extend(_params) | ||
|
||
clone.query.add_extra(extra_select, params, None, None, None, None) | ||
return clone | ||
|
||
def where(self, *args): | ||
clone = self._clone() | ||
statement = AND(*args) | ||
|
||
_sql, _params = statement.as_sql(self.quote_name, clone) | ||
if hasattr(_sql, 'to_str'): | ||
_sql = _sql.to_str() | ||
|
||
clone.query.where.add(ExtraWhere([_sql], _params), "AND") | ||
return clone | ||
|
||
def quote_name(self, name): | ||
if name.startswith('"') and name.endswith('"'): | ||
return name # Quoting once is enough. | ||
return '"%s"' % name | ||
|
||
|
||
|
||
class ExpressionManagerMixin(object): | ||
def annotate_functions(self, **kwargs): | ||
return self.get_query_set().annotate_functions(**kwargs) | ||
|
||
def where(self, *args): | ||
return self.get_query_set().where(*args) | ||
|
||
|
||
class ExpressionQuerySet(ExpressionQuerySetMixin, QuerySet): | ||
""" | ||
Predefined expression queryset. Usefull if you only use expresions. | ||
""" | ||
pass | ||
|
||
|
||
class ExpressionManager(ExpressionManagerMixin, models.Manager): | ||
""" | ||
Prededined expression manager what uses `ExpressionQuerySet`. | ||
""" | ||
|
||
use_for_related_fields = True | ||
|
||
def get_query_set(self): | ||
return ExpressionQuerySet(model=self.model, using=self._db) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
from django.test import TestCase | ||
|
||
from djorm_expressions.base import RawExpression, SqlExpression, SqlFunction, AND, OR | ||
from .models import Person, Profile | ||
|
||
class BitLength(SqlFunction): | ||
sql_function = "bit_length" | ||
|
||
|
||
class SqlExpressionsTests(TestCase): | ||
def setUp(self): | ||
Person.objects.all().delete() | ||
|
||
def test_raw_statements_0(self): | ||
expresion_instance = OR( | ||
AND( | ||
RawExpression("name = %s", "Foo"), | ||
RawExpression("age = %s", 14), | ||
), | ||
AND( | ||
RawExpression("name = %s", "Bar"), | ||
RawExpression("age = %s", 14), | ||
) | ||
) | ||
sql, params = expresion_instance.as_sql(None, None) | ||
self.assertEqual(sql.to_str(), "(name = %s AND age = %s) OR (name = %s AND age = %s)") | ||
self.assertEqual(params, ['Foo', 14, 'Bar', 14]) | ||
|
||
|
||
def test_string_sample_statement(self): | ||
obj = Person.objects.create(name="jose") | ||
|
||
queryset = Person.objects.where( | ||
SqlExpression(BitLength("name"), "=", 32) | ||
) | ||
self.assertEqual(queryset.count(), 1) | ||
|
||
def test_join_lookup_with_expression(self): | ||
person = Person.objects.create(name="jose") | ||
profile = Profile.objects.create(person=person) | ||
|
||
queryset = Profile.objects.where( | ||
SqlExpression(BitLength("person__name"), "=", 32) | ||
) | ||
self.assertEqual(queryset.count(), 1) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
from django.db import models | ||
|
||
from ..models import ExpressionManager | ||
|
||
class Person(models.Model): | ||
name = models.CharField(max_length=200) | ||
objects = ExpressionManager() | ||
|
||
|
||
class Profile(models.Model): | ||
person = models.ForeignKey("Person", related_name="profiles") | ||
objects = ExpressionManager() |
Oops, something went wrong.