Skip to content

Commit

Permalink
Fixed some issues with self-referential mappings.
Browse files Browse the repository at this point in the history
git-svn-id: http://tranquil.googlecode.com/svn/trunk@54 66ac46bf-b33b-0410-90a4-43c7d5e8004f
  • Loading branch information
davisp committed Oct 9, 2007
1 parent a6be190 commit f7f29df
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 52 deletions.
8 changes: 8 additions & 0 deletions proj_name/app_name/models.py
Expand Up @@ -24,3 +24,11 @@ class Admin:


def __unicode__(self): def __unicode__(self):
return "<Choice '%s'>" % self.choice return "<Choice '%s'>" % self.choice

class SelfRef(models.Model):
parent = models.ForeignKey('self',null=True)
name = models.CharField(max_length=50)

class MultiSelfRef(models.Model):
name = models.CharField(max_length=50)
ref = models.ManyToManyField('self')
49 changes: 0 additions & 49 deletions test/models/basic_tests.py

This file was deleted.

91 changes: 91 additions & 0 deletions test/models/test.py
@@ -0,0 +1,91 @@

import datetime
import os
import sys
import unittest

sys.path.insert( 0, os.path.dirname( __file__ ) + '../' )
os.environ['DJANGO_SETTINGS_MODULE'] = 'proj_name.settings'

from tranquil import Session
from tranquil.models.app_name import Poll, Choice, SelfRef, MultiSelfRef
from tranquil.translator import ORMObject

from proj_name.app_name.alchemy import User

class ModelTest(unittest.TestCase):
def ORMObject_test(self):
obj = ORMObject(name='test')
self.assertEqual( obj.name, 'test' )

def clear(self,klass):
sess = Session()
for k in sess.query(klass).all():
sess.delete(k)
sess.commit()

def model_test(self):
self.clear(Choice)
self.clear(Poll)
sess = Session()
green = Poll(question='Do you like green eggs and ham?',pub_date=datetime.datetime.today())
green.choice_set.append( Choice( choice='Yes', votes=0 ) )
green.choice_set.append( Choice( choice='No', votes=0 ) )
sess.save( green )
sess.commit()
sess = Session()
polls = sess.query(Poll).all()
self.assertEqual( len( polls ), 1 )
self.assertEqual( polls[0].question, 'Do you like green eggs and ham?' )
self.assertEqual( len( polls[0].choice_set ), 2 )
self.assert_( polls[0].choice_set[0].choice in ['Yes', 'No'] )
self.assert_( polls[0].choice_set[1].choice in ['Yes', 'No'] )

def custom_test(self):
sess = Session()
me = sess.query(User).filter_by(username='davisp').one()
self.assertEqual( me.name(), 'username=davisp' )

def self_ref_test(self):
self.clear(SelfRef)
sess = Session()
root = SelfRef(name='root')
sess.save( root )
c1 = SelfRef(name='c1',parent=root)
sess.save( c1 )
sess.commit()
sess = Session()
root = sess.query(SelfRef).filter(SelfRef.c.id==1).one()
self.assertEqual( root.name, 'root' )
self.assertEqual( len( root.selfref_set ), 1 )
self.assertEqual( root.parent, None )
child = root.selfref_set[0]
self.assertEqual( child.name, 'c1' )
self.assertEqual( child.parent, root )
self.assertEqual( child.selfref_set, [] )

def multi_self_ref_test(self):
self.clear(MultiSelfRef)
sess = Session()
root = MultiSelfRef(name='root')
root.ref.append( MultiSelfRef(name='c1') )
root.ref.append( MultiSelfRef(name='c2') )
root.multiselfref_set.append( MultiSelfRef(name='p1') )
sess.save( root )
sess.flush()
sess = Session()
n = sess.query(MultiSelfRef).filter(MultiSelfRef.c.name=='root').one()
self.assertEqual( len( n.ref ), 2 )
self.assert_( 'c1' in [ t.name for t in n.ref ] )
self.assert_( 'c2' in [ t.name for t in n.ref ] )
self.assertEqual( len( n.multiselfref_set ), 1 )
self.assert_( 'p1' in [ t.name for t in n.multiselfref_set ] )
for name in [ 'c1', 'c2']:
n = sess.query(MultiSelfRef).filter(MultiSelfRef.c.name==name).one()
self.assertEqual( n.ref, [] )
self.assertEqual( len( n.multiselfref_set ), 1 )
self.assert_( 'root' in [ t.name for t in n.multiselfref_set ] )
n = sess.query(MultiSelfRef).filter(MultiSelfRef.c.name=='p1').one()
self.assertEqual( len( n.ref ), 1 )
self.assert_( 'root' in [ t.name for t in n.ref ] )
self.assertEqual( n.multiselfref_set, [] )
30 changes: 27 additions & 3 deletions tranquil/translator.py
@@ -1,4 +1,5 @@


from pprint import pprint
from inspect import getmembers from inspect import getmembers


from django.conf import settings from django.conf import settings
Expand Down Expand Up @@ -37,12 +38,14 @@ def __init__(self,model,field):
self.field = field self.field = field


def get_backref(self): def get_backref(self):
ret = ''
if self.field.rel.multiple: if self.field.rel.multiple:
if getattr( self.field.rel, 'symmetrical', False ) and self.model == self.field.rel.to: if getattr( self.field.rel, 'symmetrical', False ) and self.model == self.field.rel.to:
return None ret = None
return self.field.rel.related_name or ( self.model._meta.object_name.lower() + '_set' ) ret = self.field.rel.related_name or ( self.model._meta.object_name.lower() + '_set' )
else: else:
return self.field.rel.related_name or ( self.model._meta.object_name.lower() ) ret = self.field.rel.related_name or ( self.model._meta.object_name.lower() )
return ret


def get_m2m_table(self): def get_m2m_table(self):
if self.field.rel.multiple: if self.field.rel.multiple:
Expand All @@ -56,13 +59,34 @@ def add_fkey(self,kwargs,column):
return kwargs return kwargs


def props(self,tables,mt_map): def props(self,tables,mt_map):
#print '\nMODEL: %s' % self.model.__name__
kwargs = {} kwargs = {}
fn = getattr( self, self.field.__class__.__name__ ) fn = getattr( self, self.field.__class__.__name__ )
kwargs = fn( kwargs, tables, mt_map ) kwargs = fn( kwargs, tables, mt_map )
brefargs = kwargs.copy() brefargs = kwargs.copy()
if brefargs.get( 'secondary' ) is not None: if brefargs.get( 'secondary' ) is not None:
del brefargs['secondary'] del brefargs['secondary']
#pprint( brefargs )
brefargs = self.switch_joins( brefargs )
kwargs['backref'] = backref( self.get_backref(), **brefargs ) kwargs['backref'] = backref( self.get_backref(), **brefargs )
kwargs = self.add_remote_side( kwargs, mt_map )
#pprint( kwargs )
return kwargs

def switch_joins(self,kwargs):
if self.field.__class__.__name__ != 'ManyToManyField':
return kwargs
prim = kwargs['primaryjoin']
kwargs['primaryjoin'] = kwargs['secondaryjoin']
kwargs['secondaryjoin'] = prim
return kwargs

def add_remote_side(self,kwargs,mt_map):
if self.model != self.field.rel.to or self.field.__class__.__name__ == 'ManyToManyField':
return kwargs
to = mt_map[self.field.rel.to]
tcol = getattr( to.c, self.field.rel.field_name )
kwargs['remote_side'] = [ tcol ]
return kwargs return kwargs


def add_primary_join(self,kwargs,mt_map): def add_primary_join(self,kwargs,mt_map):
Expand Down

0 comments on commit f7f29df

Please sign in to comment.