Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP
Browse files

Adding items from todo list, including inheritance and options testing

and topo sort
  • Loading branch information...
commit 54fcef2592fa31580ccbcbef53ce5b9ea894f939 1 parent e8a068d
@coleifer authored
Showing with 195 additions and 13 deletions.
  1. +0 −4 TODO.rst
  2. +30 −2 peewee.py
  3. +165 −7 tests.py
View
4 TODO.rst
@@ -1,10 +1,6 @@
todo
====
-* Q() with django syntax
-* inheritance test
-* model options test
-* topo sort
* backwards compat, esp places where existing api allows strings
* stronger input validation?
* docs
View
32 peewee.py
@@ -497,7 +497,7 @@ def __get__(self, instance, instance_type=None):
class ForeignKeyField(Field):
def __init__(self, rel_model, null=False, related_name=None, cascade=False, extra=None, *args, **kwargs):
self.rel_model = rel_model
- self.related_name = related_name
+ self._related_name = related_name
self.cascade = cascade
self.extra = extra
@@ -516,7 +516,7 @@ def add_to_class(self, model_class, name):
model_class._meta.fields[self.name] = self
model_class._meta.columns[self.db_column] = self
- self.related_name = self.related_name or '%s_set' % (model_class._meta.name)
+ self.related_name = self._related_name or '%s_set' % (model_class._meta.name)
if self.rel_model == 'self':
self.rel_model = self.model_class
@@ -1976,3 +1976,31 @@ def __eq__(self, other):
def __ne__(self, other):
return not self == other
+
+
+def create_model_tables(models, **create_table_kwargs):
+ """Create tables for all given models (in the right order)."""
+ for m in sort_models_topologically(models):
+ m.create_table(**create_table_kwargs)
+
+def drop_model_tables(models, **drop_table_kwargs):
+ """Drop tables for all given models (in the right order)."""
+ for m in reversed(sort_models_topologically(models)):
+ m.drop_table(**drop_table_kwargs)
+
+def sort_models_topologically(models):
+ """Sort models topologically so that parents will precede children."""
+ models = set(models)
+ seen = set()
+ ordering = []
+ def dfs(model):
+ if model in models and model not in seen:
+ seen.add(model)
+ for child_model in model._meta.reverse_rel.values():
+ dfs(child_model)
+ ordering.append(model) # parent will follow descendants
+ # order models by name and table initially to guarantee a total ordering
+ names = lambda m: (m._meta.name, m._meta.db_table)
+ for m in sorted(models, key=names, reverse=True):
+ dfs(m)
+ return list(reversed(ordering)) # want parents first in output ordering
View
172 tests.py
@@ -94,7 +94,7 @@ class Meta:
db_table = 'users'
class Blog(TestModel):
- user = ForeignKeyField(User, related_name='blogs')
+ user = ForeignKeyField(User)
title = CharField(max_length=25)
content = TextField(default='')
pub_date = DateTimeField(null=True)
@@ -178,9 +178,13 @@ class Meta:
(('f2', 'f3'), False),
)
+class BlogTwo(Blog):
+ title = TextField()
+ extra_field = CharField()
+
MODELS = [User, Blog, Comment, Relationship, NullModel, UniqueModel, OrderedModel, Category, UserCategory,
- NonIntModel, NonIntRelModel, DBUser, DBBlog, SeqModelA, SeqModelB, MultiIndexModel]
+ NonIntModel, NonIntRelModel, DBUser, DBBlog, SeqModelA, SeqModelB, MultiIndexModel, BlogTwo]
INT = test_db.interpolation
def drop_tables(only=None):
@@ -762,8 +766,8 @@ def test_related_name(self):
b12 = Blog.create(user=u1, title='b12')
b2 = Blog.create(user=u2, title='b2')
- self.assertEqual([b.title for b in u1.blogs], ['b11', 'b12'])
- self.assertEqual([b.title for b in u2.blogs], ['b2'])
+ self.assertEqual([b.title for b in u1.blog_set], ['b11', 'b12'])
+ self.assertEqual([b.title for b in u2.blog_set], ['b2'])
def test_fk_exceptions(self):
c1 = Category.create(name='c1')
@@ -1408,9 +1412,116 @@ def reader_thread(q, num):
self.assertEqual(data_queue.qsize(), 100)
-class ModelInheritanceTestCase(BasePeeweeTestCase):
- # TODO
- pass
+class ModelOptionInheritanceTestCase(BasePeeweeTestCase):
+ def test_db_table(self):
+ self.assertEqual(User._meta.db_table, 'users')
+
+ class Foo(TestModel):
+ pass
+ self.assertEqual(Foo._meta.db_table, 'foo')
+
+ class Foo2(TestModel):
+ pass
+ self.assertEqual(Foo2._meta.db_table, 'foo2')
+
+ class Foo_3(TestModel):
+ pass
+ self.assertEqual(Foo_3._meta.db_table, 'foo_3')
+
+ def test_option_inheritance(self):
+ x_test_db = SqliteDatabase('testing.db')
+ child2_db = SqliteDatabase('child2.db')
+
+ class FakeUser(Model):
+ pass
+
+ class ParentModel(Model):
+ title = CharField()
+ user = ForeignKeyField(FakeUser)
+
+ class Meta:
+ database = x_test_db
+
+ class ChildModel(ParentModel):
+ pass
+
+ class ChildModel2(ParentModel):
+ special_field = CharField()
+
+ class Meta:
+ database = child2_db
+
+ class GrandChildModel(ChildModel):
+ pass
+
+ class GrandChildModel2(ChildModel2):
+ special_field = TextField()
+
+ self.assertEqual(ParentModel._meta.database.database, 'testing.db')
+ self.assertEqual(ParentModel._meta.model_class, ParentModel)
+
+ self.assertEqual(ChildModel._meta.database.database, 'testing.db')
+ self.assertEqual(ChildModel._meta.model_class, ChildModel)
+ self.assertEqual(sorted(ChildModel._meta.fields.keys()), [
+ 'id', 'title', 'user'
+ ])
+
+ self.assertEqual(ChildModel2._meta.database.database, 'child2.db')
+ self.assertEqual(ChildModel2._meta.model_class, ChildModel2)
+ self.assertEqual(sorted(ChildModel2._meta.fields.keys()), [
+ 'id', 'special_field', 'title', 'user'
+ ])
+
+ self.assertEqual(GrandChildModel._meta.database.database, 'testing.db')
+ self.assertEqual(GrandChildModel._meta.model_class, GrandChildModel)
+ self.assertEqual(sorted(GrandChildModel._meta.fields.keys()), [
+ 'id', 'title', 'user'
+ ])
+
+ self.assertEqual(GrandChildModel2._meta.database.database, 'child2.db')
+ self.assertEqual(GrandChildModel2._meta.model_class, GrandChildModel2)
+ self.assertEqual(sorted(GrandChildModel2._meta.fields.keys()), [
+ 'id', 'special_field', 'title', 'user'
+ ])
+ self.assertTrue(isinstance(GrandChildModel2._meta.fields['special_field'], TextField))
+
+
+class ModelInheritanceTestCase(ModelTestCase):
+ requires = [Blog, BlogTwo, User]
+
+ def test_model_inheritance_attrs(self):
+ self.assertEqual(Blog._meta.get_field_names(), ['pk', 'user', 'title', 'content', 'pub_date'])
+ self.assertEqual(BlogTwo._meta.get_field_names(), ['id', 'user', 'content', 'pub_date', 'title', 'extra_field'])
+
+ self.assertEqual(Blog._meta.primary_key.name, 'pk')
+ self.assertEqual(BlogTwo._meta.primary_key.name, 'id')
+
+ self.assertEqual(Blog.user.related_name, 'blog_set')
+ self.assertEqual(BlogTwo.user.related_name, 'blogtwo_set')
+
+ self.assertEqual(User.blog_set.rel_model, Blog)
+ self.assertEqual(User.blogtwo_set.rel_model, BlogTwo)
+
+ self.assertFalse(BlogTwo._meta.db_table == Blog._meta.db_table)
+
+ def test_model_inheritance_flow(self):
+ u = User.create(username='u')
+
+ b = Blog.create(title='b', user=u)
+ b2 = BlogTwo.create(title='b2', extra_field='foo', user=u)
+
+ self.assertEqual(list(u.blog_set), [b])
+ self.assertEqual(list(u.blogtwo_set), [b2])
+
+ self.assertEqual(Blog.select().count(), 1)
+ self.assertEqual(BlogTwo.select().count(), 1)
+
+ b_from_db = Blog.get(pk=b.pk)
+ b2_from_db = BlogTwo.get(id=b2.id)
+
+ self.assertEqual(b_from_db.user, u)
+ self.assertEqual(b2_from_db.user, u)
+ self.assertEqual(b2_from_db.extra_field, 'foo')
class DatabaseTestCase(BasePeeweeTestCase):
@@ -1449,6 +1560,53 @@ def test_connection_state(self):
self.assertFalse(test_db.is_closed())
+class TopologicalSortTestCase(unittest.TestCase):
+ def test_topological_sort_fundamentals(self):
+ FKF = ForeignKeyField
+ # we will be topo-sorting the following models
+ class A(Model): pass
+ class B(Model): a = FKF(A) # must follow A
+ class C(Model): a, b = FKF(A), FKF(B) # must follow A and B
+ class D(Model): c = FKF(C) # must follow A and B and C
+ class E(Model): e = FKF('self')
+ # but excluding this model, which is a child of E
+ class Excluded(Model): e = FKF(E)
+
+ # property 1: output ordering must not depend upon input order
+ repeatable_ordering = None
+ for input_ordering in permutations([A, B, C, D, E]):
+ output_ordering = sort_models_topologically(input_ordering)
+ repeatable_ordering = repeatable_ordering or output_ordering
+ self.assertEqual(repeatable_ordering, output_ordering)
+
+ # property 2: output ordering must have same models as input
+ self.assertEqual(len(output_ordering), 5)
+ self.assertFalse(Excluded in output_ordering)
+
+ # property 3: parents must precede children
+ def assert_precedes(X, Y):
+ lhs, rhs = map(output_ordering.index, [X, Y])
+ self.assertTrue(lhs < rhs)
+ assert_precedes(A, B)
+ assert_precedes(B, C) # if true, C follows A by transitivity
+ assert_precedes(C, D) # if true, D follows A and B by transitivity
+
+ # property 4: independent model hierarchies must be in name order
+ assert_precedes(A, E)
+
+def permutations(xs):
+ if not xs:
+ yield []
+ else:
+ for y, ys in selections(xs):
+ for pys in permutations(ys):
+ yield [y] + pys
+
+def selections(xs):
+ for i in xrange(len(xs)):
+ yield (xs[i], xs[:i] + xs[i + 1:])
+
+
if test_db.for_update:
class ForUpdateTestCase(ModelTestCase):
requires = [User]
Please sign in to comment.
Something went wrong with that request. Please try again.