Skip to content
This repository has been archived by the owner on Jan 18, 2020. It is now read-only.

Commit

Permalink
Support ForeignKey.to_field
Browse files Browse the repository at this point in the history
Fix #7

Signed-off-by: Byron Ruth <b@devel.io>
  • Loading branch information
bruth committed Feb 3, 2017
1 parent 5c8edf6 commit d476297
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 13 deletions.
26 changes: 13 additions & 13 deletions modeltree/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,14 +132,6 @@ def m2m_reverse_field(self):
else:
return f.field.m2m_reverse_name()

@property
def foreignkey_field_column(self):
f = getattr(self.parent_model, self.accessor_name)
if self.reverse:
return f.related.field.column
else:
return f.field.column

@property
def foreignkey_field(self):
f = getattr(self.parent_model, self.accessor_name)
Expand Down Expand Up @@ -183,7 +175,8 @@ def get_joins(self, **kwargs):
self.get_connection(None, self.parent.db_table, None, None)
joins.append(copy)

# Setup two connections for m2m.
# Setup two connections for m2m. The first is the left model table
# to the "intermediate table" to the right model table.
if self.relation == 'manytomany':
c1 = self.get_connection(
self.parent.db_table,
Expand Down Expand Up @@ -216,13 +209,20 @@ def get_joins(self, **kwargs):
joins.append(copy1)
joins.append(copy2)
else:
# A reverse direction goes from referred model back to model
# that is referring to it.
if self.reverse:
lhs = self.foreignkey_field.rel.get_related_field().column
rhs = self.foreignkey_field.column
else:
lhs = self.foreignkey_field.column
rhs = self.foreignkey_field.rel.get_related_field().column

c1 = self.get_connection(
self.parent.db_table,
self.db_table,
self.parent.pk_column if self.reverse
else self.foreignkey_field_column,
self.foreignkey_field_column if self.reverse
else self.pk_column,
lhs,
rhs,
)

copy = kwargs.copy()
Expand Down
Empty file.
38 changes: 38 additions & 0 deletions tests/cases/regressions/issue7/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from django.db import models


class A(models.Model):
study_id = models.CharField(max_length=20,
unique=True,
db_column='study_id')

class Meta:
db_table = 'a'


class B(models.Model):
study_id = models.ForeignKey('A',
to_field='study_id',
unique=True,
db_column='study_id')

class Meta:
db_table = 'b'


class C(models.Model):
id = models.IntegerField(primary_key=True, db_column='c_id')
bs = models.ManyToManyField('B', through='CB')

class Meta:
db_table = 'c'


class CB(models.Model):
c = models.ForeignKey('C', db_column='some_c_id')
sid = models.ForeignKey('B',
to_field='study_id',
db_column='study_id')

class Meta:
db_table = 'cb'
39 changes: 39 additions & 0 deletions tests/cases/regressions/issue7/tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from django.test import TestCase
from modeltree.tree import trees
from .models import A, B, C


class Test(TestCase):
"""Wrong foreign key used when constructing joins.
https://github.com/cbmi/modeltree/issues/7
"""
def test_fk(self):
mt = trees.create(A)

joins = mt.get_joins(B)
lhs, rhs = joins[1]['connection'][2][0]
self.assertEqual((lhs, rhs), ('study_id', 'study_id'))

mt = trees.create(B)
joins = mt.get_joins(A)
lhs, rhs = joins[1]['connection'][2][0]
self.assertEqual((lhs, rhs), ('study_id', 'study_id'))

def test_m2m(self):
mt = trees.create(B)

joins = mt.get_joins(C)
lhs, rhs = joins[1]['connection'][2][0]
self.assertEqual((lhs, rhs), ('id', 'study_id'))

lhs, rhs = joins[2]['connection'][2][0]
self.assertEqual((lhs, rhs), ('some_c_id', 'c_id'))

mt = trees.create(C)

joins = mt.get_joins(B)
lhs, rhs = joins[1]['connection'][2][0]
self.assertEqual((lhs, rhs), ('c_id', 'some_c_id'))

lhs, rhs = joins[2]['connection'][2][0]
self.assertEqual((lhs, rhs), ('study_id', 'id'))

0 comments on commit d476297

Please sign in to comment.