Permalink
Browse files

Fixed some issues with self-referential mappings.

git-svn-id: http://tranquil.googlecode.com/svn/trunk@54 66ac46bf-b33b-0410-90a4-43c7d5e8004f
  • Loading branch information...
1 parent a6be190 commit f7f29dfb0a093088921f45edf8aa327689f04556 @davisp davisp committed Oct 9, 2007
Showing with 126 additions and 52 deletions.
  1. +8 −0 proj_name/app_name/models.py
  2. +0 −49 test/models/basic_tests.py
  3. +91 −0 test/models/test.py
  4. +27 −3 tranquil/translator.py
@@ -24,3 +24,11 @@ class Admin:
def __unicode__(self):
return "<Choice '%s'>" % self.choice
+
+class SelfRef(models.Model):
+ parent = models.ForeignKey('self',null=True)
+ name = models.CharField(max_length=50)
+
+class MultiSelfRef(models.Model):
+ name = models.CharField(max_length=50)
+ ref = models.ManyToManyField('self')
View
@@ -1,49 +0,0 @@
-
-import datetime
-import os
-import sys
-import unittest
-
-sys.path.insert( 0, os.path.dirname( __file__ ) + '../' )
-os.environ['DJANGO_SETTINGS_MODULE'] = 'proj_name.settings'
-
-from tranquil import Session
-from tranquil.models.app_name import Poll, Choice
-from tranquil.translator import ORMObject
-
-from proj_name.app_name.alchemy import User
-
-class ModelTest(unittest.TestCase):
- def ORMObject_test(self):
- obj = ORMObject(name='test')
- self.assertEqual( obj.name, 'test' )
-
- def clear(self):
- sess = Session()
- for c in sess.query(Choice):
- sess.delete(c)
- for p in sess.query(Poll):
- sess.delete(p)
- sess.commit()
-
- def model_test(self):
- self.clear()
- sess = Session()
- green = Poll(question='Do you like green eggs and ham?',pub_date=datetime.datetime.today())
- green.choice_set.append( Choice( choice='Yes', votes=0 ) )
- green.choice_set.append( Choice( choice='No', votes=0 ) )
- sess.save( green )
- sess.commit()
- sess = Session()
- polls = sess.query(Poll).all()
- self.assertEqual( len( polls ), 1 )
- self.assertEqual( polls[0].question, 'Do you like green eggs and ham?' )
- self.assertEqual( len( polls[0].choice_set ), 2 )
- self.assertEqual( polls[0].choice_set[0].choice in ['Yes', 'No'], True )
- self.assertEqual( polls[0].choice_set[1].choice in ['Yes', 'No'], True )
-
- def custom_test(self):
- sess = Session()
- me = sess.query(User).filter_by(username='davisp').one()
- self.assertEqual( me.name(), 'username=davisp' )
-
View
@@ -0,0 +1,91 @@
+
+import datetime
+import os
+import sys
+import unittest
+
+sys.path.insert( 0, os.path.dirname( __file__ ) + '../' )
+os.environ['DJANGO_SETTINGS_MODULE'] = 'proj_name.settings'
+
+from tranquil import Session
+from tranquil.models.app_name import Poll, Choice, SelfRef, MultiSelfRef
+from tranquil.translator import ORMObject
+
+from proj_name.app_name.alchemy import User
+
+class ModelTest(unittest.TestCase):
+ def ORMObject_test(self):
+ obj = ORMObject(name='test')
+ self.assertEqual( obj.name, 'test' )
+
+ def clear(self,klass):
+ sess = Session()
+ for k in sess.query(klass).all():
+ sess.delete(k)
+ sess.commit()
+
+ def model_test(self):
+ self.clear(Choice)
+ self.clear(Poll)
+ sess = Session()
+ green = Poll(question='Do you like green eggs and ham?',pub_date=datetime.datetime.today())
+ green.choice_set.append( Choice( choice='Yes', votes=0 ) )
+ green.choice_set.append( Choice( choice='No', votes=0 ) )
+ sess.save( green )
+ sess.commit()
+ sess = Session()
+ polls = sess.query(Poll).all()
+ self.assertEqual( len( polls ), 1 )
+ self.assertEqual( polls[0].question, 'Do you like green eggs and ham?' )
+ self.assertEqual( len( polls[0].choice_set ), 2 )
+ self.assert_( polls[0].choice_set[0].choice in ['Yes', 'No'] )
+ self.assert_( polls[0].choice_set[1].choice in ['Yes', 'No'] )
+
+ def custom_test(self):
+ sess = Session()
+ me = sess.query(User).filter_by(username='davisp').one()
+ self.assertEqual( me.name(), 'username=davisp' )
+
+ def self_ref_test(self):
+ self.clear(SelfRef)
+ sess = Session()
+ root = SelfRef(name='root')
+ sess.save( root )
+ c1 = SelfRef(name='c1',parent=root)
+ sess.save( c1 )
+ sess.commit()
+ sess = Session()
+ root = sess.query(SelfRef).filter(SelfRef.c.id==1).one()
+ self.assertEqual( root.name, 'root' )
+ self.assertEqual( len( root.selfref_set ), 1 )
+ self.assertEqual( root.parent, None )
+ child = root.selfref_set[0]
+ self.assertEqual( child.name, 'c1' )
+ self.assertEqual( child.parent, root )
+ self.assertEqual( child.selfref_set, [] )
+
+ def multi_self_ref_test(self):
+ self.clear(MultiSelfRef)
+ sess = Session()
+ root = MultiSelfRef(name='root')
+ root.ref.append( MultiSelfRef(name='c1') )
+ root.ref.append( MultiSelfRef(name='c2') )
+ root.multiselfref_set.append( MultiSelfRef(name='p1') )
+ sess.save( root )
+ sess.flush()
+ sess = Session()
+ n = sess.query(MultiSelfRef).filter(MultiSelfRef.c.name=='root').one()
+ self.assertEqual( len( n.ref ), 2 )
+ self.assert_( 'c1' in [ t.name for t in n.ref ] )
+ self.assert_( 'c2' in [ t.name for t in n.ref ] )
+ self.assertEqual( len( n.multiselfref_set ), 1 )
+ self.assert_( 'p1' in [ t.name for t in n.multiselfref_set ] )
+ for name in [ 'c1', 'c2']:
+ n = sess.query(MultiSelfRef).filter(MultiSelfRef.c.name==name).one()
+ self.assertEqual( n.ref, [] )
+ self.assertEqual( len( n.multiselfref_set ), 1 )
+ self.assert_( 'root' in [ t.name for t in n.multiselfref_set ] )
+ n = sess.query(MultiSelfRef).filter(MultiSelfRef.c.name=='p1').one()
+ self.assertEqual( len( n.ref ), 1 )
+ self.assert_( 'root' in [ t.name for t in n.ref ] )
+ self.assertEqual( n.multiselfref_set, [] )
View
@@ -1,4 +1,5 @@
+from pprint import pprint
from inspect import getmembers
from django.conf import settings
@@ -37,12 +38,14 @@ def __init__(self,model,field):
self.field = field
def get_backref(self):
+ ret = ''
if self.field.rel.multiple:
if getattr( self.field.rel, 'symmetrical', False ) and self.model == self.field.rel.to:
- return None
- return self.field.rel.related_name or ( self.model._meta.object_name.lower() + '_set' )
+ ret = None
+ ret = self.field.rel.related_name or ( self.model._meta.object_name.lower() + '_set' )
else:
- return self.field.rel.related_name or ( self.model._meta.object_name.lower() )
+ ret = self.field.rel.related_name or ( self.model._meta.object_name.lower() )
+ return ret
def get_m2m_table(self):
if self.field.rel.multiple:
@@ -56,13 +59,34 @@ def add_fkey(self,kwargs,column):
return kwargs
def props(self,tables,mt_map):
+ #print '\nMODEL: %s' % self.model.__name__
kwargs = {}
fn = getattr( self, self.field.__class__.__name__ )
kwargs = fn( kwargs, tables, mt_map )
brefargs = kwargs.copy()
if brefargs.get( 'secondary' ) is not None:
del brefargs['secondary']
+ #pprint( brefargs )
+ brefargs = self.switch_joins( brefargs )
kwargs['backref'] = backref( self.get_backref(), **brefargs )
+ kwargs = self.add_remote_side( kwargs, mt_map )
+ #pprint( kwargs )
+ return kwargs
+
+ def switch_joins(self,kwargs):
+ if self.field.__class__.__name__ != 'ManyToManyField':
+ return kwargs
+ prim = kwargs['primaryjoin']
+ kwargs['primaryjoin'] = kwargs['secondaryjoin']
+ kwargs['secondaryjoin'] = prim
+ return kwargs
+
+ def add_remote_side(self,kwargs,mt_map):
+ if self.model != self.field.rel.to or self.field.__class__.__name__ == 'ManyToManyField':
+ return kwargs
+ to = mt_map[self.field.rel.to]
+ tcol = getattr( to.c, self.field.rel.field_name )
+ kwargs['remote_side'] = [ tcol ]
return kwargs
def add_primary_join(self,kwargs,mt_map):

0 comments on commit f7f29df

Please sign in to comment.