Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

Adding items from todo list, including inheritance and options testing

and topo sort
  • Loading branch information...
commit 54fcef2592fa31580ccbcbef53ce5b9ea894f939 1 parent e8a068d
@coleifer authored
Showing with 195 additions and 13 deletions.
  1. +0 −4 TODO.rst
  2. +30 −2 peewee.py
  3. +165 −7 tests.py
View
4 TODO.rst
@@ -1,10 +1,6 @@
todo
====
-* Q() with django syntax
-* inheritance test
-* model options test
-* topo sort
* backwards compat, esp places where existing api allows strings
* stronger input validation?
* docs
View
32 peewee.py
@@ -497,7 +497,7 @@ def __get__(self, instance, instance_type=None):
class ForeignKeyField(Field):
def __init__(self, rel_model, null=False, related_name=None, cascade=False, extra=None, *args, **kwargs):
self.rel_model = rel_model
- self.related_name = related_name
+ self._related_name = related_name
self.cascade = cascade
self.extra = extra
@@ -516,7 +516,7 @@ def add_to_class(self, model_class, name):
model_class._meta.fields[self.name] = self
model_class._meta.columns[self.db_column] = self
- self.related_name = self.related_name or '%s_set' % (model_class._meta.name)
+ self.related_name = self._related_name or '%s_set' % (model_class._meta.name)
if self.rel_model == 'self':
self.rel_model = self.model_class
@@ -1976,3 +1976,31 @@ def __eq__(self, other):
def __ne__(self, other):
return not self == other
+
+
+def create_model_tables(models, **create_table_kwargs):
+ """Create tables for all given models (in the right order)."""
+ for m in sort_models_topologically(models):
+ m.create_table(**create_table_kwargs)
+
+def drop_model_tables(models, **drop_table_kwargs):
+ """Drop tables for all given models (in the right order)."""
+ for m in reversed(sort_models_topologically(models)):
+ m.drop_table(**drop_table_kwargs)
+
+def sort_models_topologically(models):
+ """Sort models topologically so that parents will precede children."""
+ models = set(models)
+ seen = set()
+ ordering = []
+ def dfs(model):
+ if model in models and model not in seen:
+ seen.add(model)
+ for child_model in model._meta.reverse_rel.values():
+ dfs(child_model)
+ ordering.append(model) # parent will follow descendants
+ # order models by name and table initially to guarantee a total ordering
+ names = lambda m: (m._meta.name, m._meta.db_table)
+ for m in sorted(models, key=names, reverse=True):
+ dfs(m)
+ return list(reversed(ordering)) # want parents first in output ordering
View
172 tests.py
@@ -94,7 +94,7 @@ class Meta:
db_table = 'users'
class Blog(TestModel):
- user = ForeignKeyField(User, related_name='blogs')
+ user = ForeignKeyField(User)
title = CharField(max_length=25)
content = TextField(default='')
pub_date = DateTimeField(null=True)
@@ -178,9 +178,13 @@ class Meta:
(('f2', 'f3'), False),
)
+class BlogTwo(Blog):
+ title = TextField()
+ extra_field = CharField()
+
MODELS = [User, Blog, Comment, Relationship, NullModel, UniqueModel, OrderedModel, Category, UserCategory,
- NonIntModel, NonIntRelModel, DBUser, DBBlog, SeqModelA, SeqModelB, MultiIndexModel]
+ NonIntModel, NonIntRelModel, DBUser, DBBlog, SeqModelA, SeqModelB, MultiIndexModel, BlogTwo]
INT = test_db.interpolation
def drop_tables(only=None):
@@ -762,8 +766,8 @@ def test_related_name(self):
b12 = Blog.create(user=u1, title='b12')
b2 = Blog.create(user=u2, title='b2')
- self.assertEqual([b.title for b in u1.blogs], ['b11', 'b12'])
- self.assertEqual([b.title for b in u2.blogs], ['b2'])
+ self.assertEqual([b.title for b in u1.blog_set], ['b11', 'b12'])
+ self.assertEqual([b.title for b in u2.blog_set], ['b2'])
def test_fk_exceptions(self):
c1 = Category.create(name='c1')
@@ -1408,9 +1412,116 @@ def reader_thread(q, num):
self.assertEqual(data_queue.qsize(), 100)
-class ModelInheritanceTestCase(BasePeeweeTestCase):
- # TODO
- pass
+class ModelOptionInheritanceTestCase(BasePeeweeTestCase):
+ def test_db_table(self):
+ self.assertEqual(User._meta.db_table, 'users')
+
+ class Foo(TestModel):
+ pass
+ self.assertEqual(Foo._meta.db_table, 'foo')
+
+ class Foo2(TestModel):
+ pass
+ self.assertEqual(Foo2._meta.db_table, 'foo2')
+
+ class Foo_3(TestModel):
+ pass
+ self.assertEqual(Foo_3._meta.db_table, 'foo_3')
+
+ def test_option_inheritance(self):
+ x_test_db = SqliteDatabase('testing.db')
+ child2_db = SqliteDatabase('child2.db')
+
+ class FakeUser(Model):
+ pass
+
+ class ParentModel(Model):
+ title = CharField()
+ user = ForeignKeyField(FakeUser)
+
+ class Meta:
+ database = x_test_db
+
+ class ChildModel(ParentModel):
+ pass
+
+ class ChildModel2(ParentModel):
+ special_field = CharField()
+
+ class Meta:
+ database = child2_db
+
+ class GrandChildModel(ChildModel):
+ pass
+
+ class GrandChildModel2(ChildModel2):
+ special_field = TextField()
+
+ self.assertEqual(ParentModel._meta.database.database, 'testing.db')
+ self.assertEqual(ParentModel._meta.model_class, ParentModel)
+
+ self.assertEqual(ChildModel._meta.database.database, 'testing.db')
+ self.assertEqual(ChildModel._meta.model_class, ChildModel)
+ self.assertEqual(sorted(ChildModel._meta.fields.keys()), [
+ 'id', 'title', 'user'
+ ])
+
+ self.assertEqual(ChildModel2._meta.database.database, 'child2.db')
+ self.assertEqual(ChildModel2._meta.model_class, ChildModel2)
+ self.assertEqual(sorted(ChildModel2._meta.fields.keys()), [
+ 'id', 'special_field', 'title', 'user'
+ ])
+
+ self.assertEqual(GrandChildModel._meta.database.database, 'testing.db')
+ self.assertEqual(GrandChildModel._meta.model_class, GrandChildModel)
+ self.assertEqual(sorted(GrandChildModel._meta.fields.keys()), [
+ 'id', 'title', 'user'
+ ])
+
+ self.assertEqual(GrandChildModel2._meta.database.database, 'child2.db')
+ self.assertEqual(GrandChildModel2._meta.model_class, GrandChildModel2)
+ self.assertEqual(sorted(GrandChildModel2._meta.fields.keys()), [
+ 'id', 'special_field', 'title', 'user'
+ ])
+ self.assertTrue(isinstance(GrandChildModel2._meta.fields['special_field'], TextField))
+
+
+class ModelInheritanceTestCase(ModelTestCase):
+ requires = [Blog, BlogTwo, User]
+
+ def test_model_inheritance_attrs(self):
+ self.assertEqual(Blog._meta.get_field_names(), ['pk', 'user', 'title', 'content', 'pub_date'])
+ self.assertEqual(BlogTwo._meta.get_field_names(), ['id', 'user', 'content', 'pub_date', 'title', 'extra_field'])
+
+ self.assertEqual(Blog._meta.primary_key.name, 'pk')
+ self.assertEqual(BlogTwo._meta.primary_key.name, 'id')
+
+ self.assertEqual(Blog.user.related_name, 'blog_set')
+ self.assertEqual(BlogTwo.user.related_name, 'blogtwo_set')
+
+ self.assertEqual(User.blog_set.rel_model, Blog)
+ self.assertEqual(User.blogtwo_set.rel_model, BlogTwo)
+
+ self.assertFalse(BlogTwo._meta.db_table == Blog._meta.db_table)
+
+ def test_model_inheritance_flow(self):
+ u = User.create(username='u')
+
+ b = Blog.create(title='b', user=u)
+ b2 = BlogTwo.create(title='b2', extra_field='foo', user=u)
+
+ self.assertEqual(list(u.blog_set), [b])
+ self.assertEqual(list(u.blogtwo_set), [b2])
+
+ self.assertEqual(Blog.select().count(), 1)
+ self.assertEqual(BlogTwo.select().count(), 1)
+
+ b_from_db = Blog.get(pk=b.pk)
+ b2_from_db = BlogTwo.get(id=b2.id)
+
+ self.assertEqual(b_from_db.user, u)
+ self.assertEqual(b2_from_db.user, u)
+ self.assertEqual(b2_from_db.extra_field, 'foo')
class DatabaseTestCase(BasePeeweeTestCase):
@@ -1449,6 +1560,53 @@ def test_connection_state(self):
self.assertFalse(test_db.is_closed())
+class TopologicalSortTestCase(unittest.TestCase):
+ def test_topological_sort_fundamentals(self):
+ FKF = ForeignKeyField
+ # we will be topo-sorting the following models
+ class A(Model): pass
+ class B(Model): a = FKF(A) # must follow A
+ class C(Model): a, b = FKF(A), FKF(B) # must follow A and B
+ class D(Model): c = FKF(C) # must follow A and B and C
+ class E(Model): e = FKF('self')
+ # but excluding this model, which is a child of E
+ class Excluded(Model): e = FKF(E)
+
+ # property 1: output ordering must not depend upon input order
+ repeatable_ordering = None
+ for input_ordering in permutations([A, B, C, D, E]):
+ output_ordering = sort_models_topologically(input_ordering)
+ repeatable_ordering = repeatable_ordering or output_ordering
+ self.assertEqual(repeatable_ordering, output_ordering)
+
+ # property 2: output ordering must have same models as input
+ self.assertEqual(len(output_ordering), 5)
+ self.assertFalse(Excluded in output_ordering)
+
+ # property 3: parents must precede children
+ def assert_precedes(X, Y):
+ lhs, rhs = map(output_ordering.index, [X, Y])
+ self.assertTrue(lhs < rhs)
+ assert_precedes(A, B)
+ assert_precedes(B, C) # if true, C follows A by transitivity
+ assert_precedes(C, D) # if true, D follows A and B by transitivity
+
+ # property 4: independent model hierarchies must be in name order
+ assert_precedes(A, E)
+
+def permutations(xs):
+ if not xs:
+ yield []
+ else:
+ for y, ys in selections(xs):
+ for pys in permutations(ys):
+ yield [y] + pys
+
+def selections(xs):
+ for i in xrange(len(xs)):
+ yield (xs[i], xs[:i] + xs[i + 1:])
+
+
if test_db.for_update:
class ForUpdateTestCase(ModelTestCase):
requires = [User]
Please sign in to comment.
Something went wrong with that request. Please try again.