Permalink
Browse files

Cleaning up some parsing bits, more tests, moving asc/desc into Expr

  • Loading branch information...
1 parent 3b11206 commit 0f4a0059d158196926258a189160ce03d75b2b7d @coleifer committed Oct 2, 2012
Showing with 146 additions and 33 deletions.
  1. +32 −13 peewee.py
  2. +114 −20 tests.py
View
@@ -171,10 +171,17 @@ def set_alias(self, a=None):
self.alias = a
return self
+ def asc(self):
+ return Ordering(self, True)
+
+ def desc(self):
+ return Ordering(self, False)
+
def _expr(op, n=False):
def inner(self, value):
return BinaryExpr(self, op, value)
return inner
+
__add__ = _expr(OP_ADD)
__sub__ = _expr(OP_SUB)
__mul__ = _expr(OP_MUL)
@@ -236,7 +243,8 @@ def __set__(self, instance, value):
instance._data[self.att_name] = value
-Ordering = namedtuple('Ordering', ('field', 'asc'))
+Ordering = namedtuple('Ordering', ('param', 'asc'))
+R = namedtuple('R', ('value',))
class Field(Expr):
@@ -291,12 +299,6 @@ def db_value(self, value):
def python_value(self, value):
return value if value is None else self.coerce(value)
- def asc(self):
- return Ordering(self, True)
-
- def desc(self):
- return Ordering(self, False)
-
class IntegerField(Field):
db_field = 'int'
@@ -606,7 +608,7 @@ def parse_expr(self, expr, alias_map=None):
if isinstance(expr, Field):
return self._parse_field(expr, alias_map)
elif isinstance(expr, Ordering):
- expr_str, params = self._parse_field(expr.field, alias_map)
+ expr_str, params = self.parse_expr(expr.param, alias_map)
expr_str += ' ASC' if expr.asc else ' DESC'
return expr_str, params
elif isinstance(expr, Func):
@@ -618,6 +620,8 @@ def parse_expr(self, expr, alias_map=None):
scalars.extend(params)
expr_str = '%s(%s)' % (expr.fn_name, ', '.join(exprs))
return self._add_alias(expr_str, expr), scalars
+ elif isinstance(expr, R):
+ return expr.value, []
elif isinstance(expr, SelectQuery):
max_alias = self._max_alias(alias_map)
clone = expr.clone()
@@ -674,10 +678,12 @@ def parse_query_node(self, qnode, alias_map):
def parse_joins(self, joins, model_class, alias_map):
parsed = []
+ seen = set()
def _traverse(curr):
- if curr not in joins:
+ if curr not in joins or curr in seen:
return
+ seen.add(curr)
for join in joins[curr]:
from_model = curr
to_model = join.model_class
@@ -1639,7 +1645,7 @@ class EmptyResultException(Exception):
class ModelOptions(object):
def __init__(self, cls, database=None, db_table=None, indexes=None,
- ordering=None, primary_key=None):
+ order_by=None, primary_key=None):
self.model_class = cls
self.name = cls.__name__.lower()
self.fields = {}
@@ -1649,7 +1655,7 @@ def __init__(self, cls, database=None, db_table=None, indexes=None,
self.database = database or default_database
self.db_table = db_table
self.indexes = indexes or []
- self.ordering = ordering
+ self.order_by = order_by
self.primary_key = primary_key
self.auto_increment = None
@@ -1661,6 +1667,16 @@ def prepared(self):
if field.default is not None:
self.defaults[field] = field.default
+ if self.order_by:
+ norm_order_by = []
+ for clause in self.order_by:
+ field = self.fields[clause.lstrip('-')]
+ if clause.startswith('-'):
+ norm_order_by.append(field.desc())
+ else:
+ norm_order_by.append(field.asc())
+ self.order_by = norm_order_by
+
def get_default_dict(self):
return dict((f, dft if not callable(dft) else dft()) for f, dft in self.defaults.items())
@@ -1687,7 +1703,7 @@ def rel_exists(self, model):
class BaseModel(type):
- inheritable_options = ['database', 'indexes', 'ordering', 'primary_key']
+ inheritable_options = ['database', 'indexes', 'order_by', 'primary_key']
def __new__(cls, name, bases, attrs):
if not bases:
@@ -1768,7 +1784,10 @@ def raw(cls, sql, *params):
@classmethod
def select(cls, *selection):
- return SelectQuery(cls, *selection)
+ query = SelectQuery(cls, *selection)
+ if cls._meta.order_by:
+ query = query.order_by(*cls._meta.order_by)
+ return query
@classmethod
def update(cls, **update):
View
@@ -93,7 +93,7 @@ class Meta:
db_table = 'users'
class Blog(TestModel):
- user = ForeignKeyField(User)
+ user = ForeignKeyField(User, related_name='blogs')
title = CharField(max_length=25)
content = TextField(default='')
pub_date = DateTimeField(null=True)
@@ -102,6 +102,10 @@ class Blog(TestModel):
def __unicode__(self):
return '%s: %s' % (self.user.username, self.title)
+class Comment(TestModel):
+ blog = ForeignKeyField(Blog, related_name='comments')
+ comment = CharField()
+
class Relationship(TestModel):
from_user = ForeignKeyField(User, related_name='relationships')
to_user = ForeignKeyField(User, related_name='related_to')
@@ -128,13 +132,13 @@ class OrderedModel(TestModel):
created = DateTimeField(default=datetime.datetime.now)
class Meta:
- ordering = (('created', 'desc'),)
+ order_by = ('-created',)
class Category(TestModel):
parent = ForeignKeyField('self', related_name='children', null=True)
name = CharField()
-MODELS = [User, Blog, Relationship, NullModel, UniqueModel, OrderedModel, Category]
+MODELS = [User, Blog, Comment, Relationship, NullModel, UniqueModel, OrderedModel, Category]
def drop_tables(only=None):
for model in reversed(MODELS):
@@ -234,6 +238,23 @@ def test_joins(self):
sq = SelectQuery(User).join(Relationship, JOIN_LEFT_OUTER, Relationship.to_user)
self.assertJoins(sq, ['LEFT OUTER JOIN "relationship" AS relationship ON users."id" = relationship."to_user_id"'])
+ def test_join_self_referential(self):
+ sq = SelectQuery(Category).join(Category)
+ self.assertJoins(sq, ['INNER JOIN "category" AS category ON category."parent_id" = category."id"'])
+
+ def test_join_both_sides(self):
+ sq = SelectQuery(Blog).join(Comment).switch(Blog).join(User)
+ self.assertJoins(sq, [
+ 'INNER JOIN "comment" AS comment ON blog."pk" = comment."blog_id"',
+ 'INNER JOIN "users" AS users ON blog."user_id" = users."id"',
+ ])
+
+ sq = SelectQuery(Blog).join(User).switch(Blog).join(Comment)
+ self.assertJoins(sq, [
+ 'INNER JOIN "users" AS users ON blog."user_id" = users."id"',
+ 'INNER JOIN "comment" AS comment ON blog."pk" = comment."blog_id"',
+ ])
+
def test_where(self):
sq = SelectQuery(User).where(User.id < 5)
self.assertWhere(sq, 'users."id" < ?', [5])
@@ -320,6 +341,41 @@ def test_having(self):
)
self.assertHaving(sq, '(Count(blog."pk") > ? OR Count(blog."pk") < ?)', [10, 2])
+ def test_ordering(self):
+ sq = SelectQuery(User).join(Blog).order_by(Blog.title)
+ self.assertOrderBy(sq, 'blog."title"', [])
+
+ sq = SelectQuery(User).join(Blog).order_by(Blog.title.asc())
+ self.assertOrderBy(sq, 'blog."title" ASC', [])
+
+ sq = SelectQuery(User).join(Blog).order_by(Blog.title.desc())
+ self.assertOrderBy(sq, 'blog."title" DESC', [])
+
+ sq = SelectQuery(User).join(Blog).order_by(User.username.desc(), Blog.title.asc())
+ self.assertOrderBy(sq, 'users."username" DESC, blog."title" ASC', [])
+
+ base_sq = SelectQuery(User, User.username, fn.Count(Blog.pk).set_alias('count')).join(Blog).group_by(User.username)
+ sq = base_sq.order_by(fn.Count(Blog.pk).desc())
+ self.assertOrderBy(sq, 'Count(blog."pk") DESC', [])
+
+ sq = base_sq.order_by(R('count'))
+ self.assertOrderBy(sq, 'count', [])
+
+ sq = OrderedModel.select()
+ self.assertOrderBy(sq, 'orderedmodel."created" DESC', [])
+
+ sq = OrderedModel.select().order_by(OrderedModel.id.asc())
+ self.assertOrderBy(sq, 'orderedmodel."id" ASC', [])
+
+ def test_paginate(self):
+ sq = SelectQuery(User).paginate(1, 20)
+ self.assertEqual(sq._limit, 20)
+ self.assertEqual(sq._offset, 0)
+
+ sq = SelectQuery(User).paginate(3, 30)
+ self.assertEqual(sq._limit, 30)
+ self.assertEqual(sq._offset, 60)
+
class UpdateTestCase(BasePeeweeTestCase):
def test_update(self):
uq = UpdateQuery(User, {User.username: 'updated'})
@@ -351,6 +407,10 @@ def test_raw(self):
rq = RawQuery(User, q, 100)
self.assertEqual(rq.sql(compiler), (q, [100]))
+class SugarTestCase(BasePeeweeTestCase):
+ # test things like filter, annotate, aggregate
+ pass
+
#
# TEST CASE USED TO PROVIDE ACCESS TO DATABASE
# FOR EXECUTION OF "LIVE" QUERIES
@@ -366,28 +426,15 @@ def setUp(self):
def create_user(self, username):
return User.create(username=username)
-
+
def create_users(self, n):
for i in range(n):
self.create_user('u%d' % (i + 1))
-class ModelAPITestCase(ModelTestCase):
- requires = [User, Blog]
-
- def test_creation(self):
- self.create_users(10)
- self.assertEqual(User.select().count(), 10)
+class QueryResultWrapperTestCase(ModelTestCase):
+ requires = [User]
- def test_saving(self):
- self.assertEqual(User.select().count(), 0)
-
- u = User(username='u1')
- u.save()
- u.save()
-
- self.assertEqual(User.select().count(), 1)
-
def test_iteration(self):
self.create_users(10)
query_start = len(self.queries())
@@ -406,10 +453,52 @@ def test_iteration(self):
another_iter = [u.username for u in qr]
self.assertEqual(another_iter, ['u%d' % i for i in range(1, 11)])
-
+
# only 1 query for these iterations
self.assertEqual(len(self.queries()) - query_start, 1)
+
+class ModelQueryTestCase(ModelTestCase):
+ requires = [User, Blog]
+
+ def test_select(self):
+ pass
+ def test_update(self):
+ pass
+ def test_insert(self):
+ pass
+ def test_delete(self):
+ pass
+ def test_raw(self):
+ pass
+
+
+class ModelAPITestCase(ModelTestCase):
+ requires = [User, Blog]
+
+ def test_related_name(self):
+ u1 = self.create_user('u1')
+ u2 = self.create_user('u2')
+ b11 = Blog.create(user=u1, title='b11')
+ b12 = Blog.create(user=u1, title='b12')
+ b2 = Blog.create(user=u2, title='b2')
+
+ self.assertEqual([b.title for b in u1.blogs], ['b11', 'b12'])
+ self.assertEqual([b.title for b in u2.blogs], ['b2'])
+
+ def test_creation(self):
+ self.create_users(10)
+ self.assertEqual(User.select().count(), 10)
+
+ def test_saving(self):
+ self.assertEqual(User.select().count(), 0)
+
+ u = User(username='u1')
+ u.save()
+ u.save()
+
+ self.assertEqual(User.select().count(), 1)
+
def test_reading(self):
u1 = self.create_user('u1')
u2 = self.create_user('u2')
@@ -444,3 +533,8 @@ def test_counting(self):
uc = User.select().where(User.username == 'u1').join(Blog).distinct().count()
self.assertEqual(uc, 1)
+
+ def test_exists(self):
+ u1 = User.create(username='u1')
+ self.assertTrue(User.select().where(User.username == 'u1').exists())
+ self.assertFalse(User.select().where(User.username == 'u2').exists())

0 comments on commit 0f4a005

Please sign in to comment.