Skip to content

Commit

Permalink
Just a few changes related to #1018.
Browse files Browse the repository at this point in the history
  • Loading branch information
Charles Leifer committed Jul 20, 2016
1 parent f93cee4 commit aad4eb5
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 34 deletions.
25 changes: 13 additions & 12 deletions peewee.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,13 +425,13 @@ def __setattr__(self, attr, value):
return super(Proxy, self).__setattr__(attr, value)

class DeferredRelation(object):
_unresolved_deferred_relations = set()
_unresolved = set()

def __init__(self, rel_model_name=None):
if rel_model_name is not None:
self._rel_model_name = rel_model_name
self._unresolved_deferred_relations.add(self)
self._rel_model_name = rel_model_name.lower()
self._unresolved.add(self)

def set_field(self, model_class, field, name):
self.model_class = model_class
self.field = field
Expand All @@ -440,13 +440,14 @@ def set_field(self, model_class, field, name):
def set_model(self, rel_model):
self.field.rel_model = rel_model
self.field.add_to_class(self.model_class, self.name)

@staticmethod
def _resolve_unresolved_deferred_relations(cls):
for deferred_relation in list(DeferredRelation._unresolved_deferred_relations):
if deferred_relation._rel_model_name == cls.__name__:
deferred_relation.set_model(cls)
DeferredRelation._unresolved_deferred_relations.discard(deferred_relation)
def resolve(model_cls):
unresolved = list(DeferredRelation._unresolved)
for dr in unresolved:
if dr._rel_model_name == model_cls.__name__.lower():
dr.set_model(model_cls)
DeferredRelation._unresolved.discard(dr)


class _CDescriptor(object):
Expand Down Expand Up @@ -4706,7 +4707,7 @@ def __new__(cls, name, bases, attrs):
if hasattr(cls, 'validate_model'):
cls.validate_model()

DeferredRelation._resolve_unresolved_deferred_relations(cls)
DeferredRelation.resolve(cls)

return cls

Expand Down
12 changes: 0 additions & 12 deletions playhouse/tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,16 +295,6 @@ class Snippet(TestModel):
SnippetDeferred.set_model(Snippet)


class Language2(TestModel):
name = CharField()
selected_snippet = ForeignKeyField(DeferredRelation('Snippet2'), null=True)


class Snippet2(TestModel):
code = TextField()
language = ForeignKeyField(Language2, related_name='snippets')


class _UpperField(CharField):
def python_value(self, value):
return value.upper() if value else value
Expand Down Expand Up @@ -455,9 +445,7 @@ class NoteFlagNullable(TestModel):
TagPostThrough,
TagPostThroughAlt,
Language,
Language2,
Snippet,
Snippet2,
Manufacturer,
CompositeKeyModel,
UserThing,
Expand Down
28 changes: 18 additions & 10 deletions playhouse/tests/test_keys.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from peewee import DeferredRelation
from peewee import Model
from peewee import SqliteDatabase
from playhouse.tests.base import compiler
from playhouse.tests.base import database_initializer
Expand Down Expand Up @@ -334,27 +336,33 @@ def setUp(self):
Language.drop_table(True)
Language.create_table()
Snippet.create_table()
Snippet2.drop_table(True)
Language2.drop_table(True)
Language2.create_table()
Snippet2.create_table()

def tearDown(self):
super(TestDeferredForeignKey, self).tearDown()
Snippet.drop_table(True)
Language.drop_table(True)
Snippet2.drop_table(True)
Language2.drop_table(True)

def test_field_definitions(self):
self.assertEqual(Snippet._meta.fields['language'].rel_model, Language)
self.assertEqual(Language._meta.fields['selected_snippet'].rel_model,
Snippet)

def test_field_definitions2(self):
self.assertEqual(Snippet2._meta.fields['language'].rel_model, Language2)
self.assertEqual(Language2._meta.fields['selected_snippet'].rel_model,
Snippet2)
def test_deferred_relation_resolution(self):
orig = len(DeferredRelation._unresolved)

class CircularRef1(Model):
circ_ref2 = ForeignKeyField(
DeferredRelation('circularref2'),
null=True)

self.assertEqual(len(DeferredRelation._unresolved), orig + 1)

class CircularRef2(Model):
circ_ref1 = ForeignKeyField(CircularRef1, null=True)

self.assertEqual(CircularRef1.circ_ref2.rel_model, CircularRef2)
self.assertEqual(CircularRef2.circ_ref1.rel_model, CircularRef1)
self.assertEqual(len(DeferredRelation._unresolved), orig)

def test_create_table_query(self):
query, params = compiler.create_table(Snippet)
Expand Down

0 comments on commit aad4eb5

Please sign in to comment.