Skip to content

Commit

Permalink
Cleanups to variable names in closure table.
Browse files Browse the repository at this point in the history
  • Loading branch information
coleifer committed Mar 5, 2017
1 parent def05be commit 41c92a7
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 15 deletions.
30 changes: 15 additions & 15 deletions playhouse/sqlite_ext.py
Expand Up @@ -667,7 +667,8 @@ class Meta:
return getattr(cls, attr)


def ClosureTable(model_class, foreign_key=None, referencing_class=None, id_column=None):
def ClosureTable(model_class, foreign_key=None, referencing_class=None,
referencing_key=None):
"""Model factory for the transitive closure extension."""
if referencing_class is None:
referencing_class = model_class
Expand All @@ -680,10 +681,9 @@ def ClosureTable(model_class, foreign_key=None, referencing_class=None, id_colum
else:
raise ValueError('Unable to find self-referential foreign key.')

primary_key = model_class._meta.primary_key

if id_column is None:
id_column = primary_key
source_key = model_class._meta.primary_key
if referencing_key is None:
referencing_key = source_key

class BaseClosureTable(VirtualModel):
depth = VirtualIntegerField()
Expand All @@ -700,7 +700,7 @@ class Meta:
def descendants(cls, node, depth=None, include_node=False):
query = (model_class
.select(model_class, cls.depth.alias('depth'))
.join(cls, on=(primary_key == cls.id))
.join(cls, on=(source_key == cls.id))
.where(cls.root == node)
.naive())
if depth is not None:
Expand All @@ -713,7 +713,7 @@ def descendants(cls, node, depth=None, include_node=False):
def ancestors(cls, node, depth=None, include_node=False):
query = (model_class
.select(model_class, cls.depth.alias('depth'))
.join(cls, on=(primary_key == cls.root))
.join(cls, on=(source_key == cls.root))
.where(cls.id == node)
.naive())
if depth:
Expand All @@ -731,26 +731,26 @@ def siblings(cls, node, include_node=False):
else:
# siblings as given in reference_class
siblings = (referencing_class
.select(id_column)
.join(cls, on=(foreign_key == cls.root))
.where((cls.id == node) & (cls.depth == 1)))
.select(referencing_key)
.join(cls, on=(foreign_key == cls.root))
.where((cls.id == node) & (cls.depth == 1)))

# the according models
query = (model_class
.select()
.where(primary_key << siblings)
.naive())
.select()
.where(source_key << siblings)
.naive())

if not include_node:
query = query.where(primary_key != node)
query = query.where(source_key != node)

return query

class Meta:
database = referencing_class._meta.database
extension_options = {
'tablename': referencing_class._meta.db_table,
'idcolumn': id_column.db_column,
'idcolumn': referencing_key.db_column,
'parentcolumn': foreign_key.db_column}
primary_key = False

Expand Down
55 changes: 55 additions & 0 deletions playhouse/tests/test_sqlite_ext.py
Expand Up @@ -1030,6 +1030,61 @@ def test_clean_query(self):
self.assertEqual(FTS5Model.clean_query(inval, '_'), outval)


@skip_if(lambda: not CLOSURE_EXTENSION)
class TestTransitiveClosureManyToMany(PeeweeTestCase):
def setUp(self):
super(TestTransitiveClosureManyToMany, self).setUp()
ext_db.load_extension(CLOSURE_EXTENSION.rstrip('.so'))
ext_db.close()

def tearDown(self):
super(TestTransitiveClosureManyToMany, self).tearDown()
ext_db.unload_extension(CLOSURE_EXTENSION.rstrip('.so'))
ext_db.close()

def test_manytomany(self):
class Person(BaseExtModel):
name = CharField()

class Relationship(BaseExtModel):
person = ForeignKeyField(Person)
relation = ForeignKeyField(Person, related_name='related_to')

PersonClosure = ClosureTable(
Person,
referencing_class=Relationship,
foreign_key=Relationship.relation,
referencing_key=Relationship.person)

ext_db.drop_tables([Person, Relationship, PersonClosure], safe=True)
ext_db.create_tables([Person, Relationship, PersonClosure])

c = Person.create(name='charlie')
m = Person.create(name='mickey')
h = Person.create(name='huey')
z = Person.create(name='zaizee')
Relationship.create(person=c, relation=h)
Relationship.create(person=c, relation=m)
Relationship.create(person=h, relation=z)
Relationship.create(person=h, relation=m)

def assertPeople(query, expected):
self.assertEqual(sorted([p.name for p in query]), expected)

PC = PersonClosure
assertPeople(PC.descendants(c), [])
assertPeople(PC.ancestors(c), ['huey', 'mickey', 'zaizee'])
assertPeople(PC.siblings(c), ['huey'])

assertPeople(PC.descendants(h), ['charlie'])
assertPeople(PC.ancestors(h), ['mickey', 'zaizee'])
assertPeople(PC.siblings(h), ['charlie'])

assertPeople(PC.descendants(z), ['charlie', 'huey'])
assertPeople(PC.ancestors(z), [])
assertPeople(PC.siblings(z), [])


@skip_if(lambda: not CLOSURE_EXTENSION)
class TestTransitiveClosureIntegration(PeeweeTestCase):
tree = {
Expand Down

0 comments on commit 41c92a7

Please sign in to comment.