From d476297febc0ef87f1a83242f4cfc2fe2b028701 Mon Sep 17 00:00:00 2001 From: Byron Ruth Date: Fri, 3 Feb 2017 07:45:33 -0500 Subject: [PATCH] Support ForeignKey.to_field Fix #7 Signed-off-by: Byron Ruth --- modeltree/tree.py | 26 +++++++-------- tests/cases/regressions/issue7/__init__.py | 0 tests/cases/regressions/issue7/models.py | 38 +++++++++++++++++++++ tests/cases/regressions/issue7/tests.py | 39 ++++++++++++++++++++++ 4 files changed, 90 insertions(+), 13 deletions(-) create mode 100644 tests/cases/regressions/issue7/__init__.py create mode 100644 tests/cases/regressions/issue7/models.py create mode 100644 tests/cases/regressions/issue7/tests.py diff --git a/modeltree/tree.py b/modeltree/tree.py index 91279ec..9782acc 100644 --- a/modeltree/tree.py +++ b/modeltree/tree.py @@ -132,14 +132,6 @@ def m2m_reverse_field(self): else: return f.field.m2m_reverse_name() - @property - def foreignkey_field_column(self): - f = getattr(self.parent_model, self.accessor_name) - if self.reverse: - return f.related.field.column - else: - return f.field.column - @property def foreignkey_field(self): f = getattr(self.parent_model, self.accessor_name) @@ -183,7 +175,8 @@ def get_joins(self, **kwargs): self.get_connection(None, self.parent.db_table, None, None) joins.append(copy) - # Setup two connections for m2m. + # Setup two connections for m2m. The first is the left model table + # to the "intermediate table" to the right model table. if self.relation == 'manytomany': c1 = self.get_connection( self.parent.db_table, @@ -216,13 +209,20 @@ def get_joins(self, **kwargs): joins.append(copy1) joins.append(copy2) else: + # A reverse direction goes from referred model back to model + # that is referring to it. + if self.reverse: + lhs = self.foreignkey_field.rel.get_related_field().column + rhs = self.foreignkey_field.column + else: + lhs = self.foreignkey_field.column + rhs = self.foreignkey_field.rel.get_related_field().column + c1 = self.get_connection( self.parent.db_table, self.db_table, - self.parent.pk_column if self.reverse - else self.foreignkey_field_column, - self.foreignkey_field_column if self.reverse - else self.pk_column, + lhs, + rhs, ) copy = kwargs.copy() diff --git a/tests/cases/regressions/issue7/__init__.py b/tests/cases/regressions/issue7/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/cases/regressions/issue7/models.py b/tests/cases/regressions/issue7/models.py new file mode 100644 index 0000000..fd40bf7 --- /dev/null +++ b/tests/cases/regressions/issue7/models.py @@ -0,0 +1,38 @@ +from django.db import models + + +class A(models.Model): + study_id = models.CharField(max_length=20, + unique=True, + db_column='study_id') + + class Meta: + db_table = 'a' + + +class B(models.Model): + study_id = models.ForeignKey('A', + to_field='study_id', + unique=True, + db_column='study_id') + + class Meta: + db_table = 'b' + + +class C(models.Model): + id = models.IntegerField(primary_key=True, db_column='c_id') + bs = models.ManyToManyField('B', through='CB') + + class Meta: + db_table = 'c' + + +class CB(models.Model): + c = models.ForeignKey('C', db_column='some_c_id') + sid = models.ForeignKey('B', + to_field='study_id', + db_column='study_id') + + class Meta: + db_table = 'cb' diff --git a/tests/cases/regressions/issue7/tests.py b/tests/cases/regressions/issue7/tests.py new file mode 100644 index 0000000..f203edd --- /dev/null +++ b/tests/cases/regressions/issue7/tests.py @@ -0,0 +1,39 @@ +from django.test import TestCase +from modeltree.tree import trees +from .models import A, B, C + + +class Test(TestCase): + """Wrong foreign key used when constructing joins. + https://github.com/cbmi/modeltree/issues/7 + """ + def test_fk(self): + mt = trees.create(A) + + joins = mt.get_joins(B) + lhs, rhs = joins[1]['connection'][2][0] + self.assertEqual((lhs, rhs), ('study_id', 'study_id')) + + mt = trees.create(B) + joins = mt.get_joins(A) + lhs, rhs = joins[1]['connection'][2][0] + self.assertEqual((lhs, rhs), ('study_id', 'study_id')) + + def test_m2m(self): + mt = trees.create(B) + + joins = mt.get_joins(C) + lhs, rhs = joins[1]['connection'][2][0] + self.assertEqual((lhs, rhs), ('id', 'study_id')) + + lhs, rhs = joins[2]['connection'][2][0] + self.assertEqual((lhs, rhs), ('some_c_id', 'c_id')) + + mt = trees.create(C) + + joins = mt.get_joins(B) + lhs, rhs = joins[1]['connection'][2][0] + self.assertEqual((lhs, rhs), ('c_id', 'some_c_id')) + + lhs, rhs = joins[2]['connection'][2][0] + self.assertEqual((lhs, rhs), ('study_id', 'id'))