diff --git a/modeltree/tree.py b/modeltree/tree.py index 06239b4..f8831ff 100644 --- a/modeltree/tree.py +++ b/modeltree/tree.py @@ -300,8 +300,15 @@ class ModelTree(object): sometimes it is necessary to exclude join paths. For example if there are three possible paths and one should never occur. - A modeltree config takes the `required_routes` and `excluded_routes` - which is a list of routes in the above format. + A modeltree config can have `required_routes` and `excluded_routes` + entries, which are lists of routes in the above format. + + A required route is defined as follows: a join to the specified target + model is only allowed from the specified source model. A single + source model can participate in multiple required routes. + + An excluded route is more obvious: joining from the + specified source model to the specified target model is not allowed. """ def __init__(self, model=None, **kwargs): @@ -337,13 +344,13 @@ def __init__(self, model=None, **kwargs): self.excluded_models = [self.get_model(label, local=False) for label in excluded_models] - # Build the routes are allowed/preferred - self._required_joins, self._required_join_fields = \ - self._build_routes(required_routes) + # Build the routes that are allowed/preferred + self._required_joins = self._build_routes( + required_routes, + allow_redundant_targets=False) # Build the routes that are excluded - self._excluded_joins, self._excluded_join_fields = \ - self._build_routes(excluded_routes) + self._excluded_joins = self._build_routes(excluded_routes) # cache each node relative their models self._nodes = {} @@ -467,11 +474,19 @@ def get_field(self, name, model=None): model = self.root_model return model._meta.get_field_by_name(name)[0] - def _build_routes(self, routes): - "Routes provide a means of specifying JOINs between two tables." + def _build_routes(self, routes, allow_redundant_targets=True): + """Routes provide a means of specifying JOINs between two tables. + + routes - a collection of dicts defining source->target mappings + with optional `field` specifier and `symmetrical` attribute. + + allow_redundant_targets - whether two routes in this collection + are allowed to have the same target - this should NOT + be allowed for required routes. + """ routes = routes or () joins = {} - join_fields = {} + targets_seen = set() for route in routes: if isinstance(route, dict): @@ -507,20 +522,23 @@ def _build_routes(self, routes): if isinstance(field, RelatedObject): field = field.field - # the `joins` hash defines pairs which are explicitly joined - # via the specified field - # if no field is defined, then the join field is implied or - # does not matter. the route is reduced to a straight lookup - joins[target] = source - if symmetrical: - joins[source] = target + if not allow_redundant_targets: + if target in targets_seen: + tpl = ('Model {0} cannot be the target of more than one ' + 'route in this list') + raise ValueError(tpl.format(target_label)) + else: + targets_seen.add(target) - if field is not None: - join_fields[(source, target)] = field - if symmetrical: - join_fields[(target, source)] = field + # The `joins` hash defines pairs which are explicitly joined + # via the specified field. + # If no field is defined, then the join field is implied or + # does not matter; the route is reduced to a straight lookup. + joins[(source, target)] = field + if symmetrical: + joins[(target, source)] = field - return joins, join_fields + return joins def _join_allowed(self, source, target, field=None): """Checks if the join between `source` and `target` via `field` @@ -540,29 +558,31 @@ def _join_allowed(self, source, target, field=None): if target == self.root_model: return False - # Check if the join is excluded via a specific field - if field and join in self._excluded_join_fields: - _field = self._excluded_join_fields[join] - if _field == field: + # Apply excluded joins if any + if join in self._excluded_joins: + _field = self._excluded_joins[join] + if _field: + if _field == field: + return False + else: return False - # Model level.. - elif source == self._excluded_joins.get(target): - return False + # Definition of required join: for the specified target, + # the only join allowed is from the specified source. + # (There can only be one 'required' rule for a given target). - # Check if the join is allowed - if target in self._required_joins: - _source = self._required_joins[target] - if _source != source: - return False - - # If a field is supplied, check to see if the field is allowed - # for this join. - if field: - _field = self._required_join_fields.get(join) - if _field and _field != field: + # Check if the join is allowed by a required rule + for (_source, _target), _field in self._required_joins.items(): + if _target == target: + if _source != source: return False + # If a field is supplied, check to see if the field is allowed + # for this join. + if field: + if _field and _field != field: + return False + return True def _filter_one2one(self, field): diff --git a/tests/cases/core/tests/test_routes.py b/tests/cases/core/tests/test_routes.py index 124776c..57d06c9 100644 --- a/tests/cases/core/tests/test_routes.py +++ b/tests/cases/core/tests/test_routes.py @@ -5,20 +5,18 @@ __all__ = ('RouterTestCase', 'FieldRouterTestCase') - -def compare_paths(self, tree, models, expected_paths): +def compare_paths(self, tree, expected_paths): for i, model in enumerate(self.models): path = [n.model for n in tree._node_path(model)] self.assertEqual(path, expected_paths[i]) -def compare_paths_with_accessor(self, tree, models, expected_paths): +def compare_paths_with_accessor(self, tree, expected_paths): for i, model in enumerate(self.models): path = [(n.model, n.accessor_name) for n in tree._node_path(model)] self.assertEqual(path, expected_paths[i]) - class RouterTestCase(TestCase): def setUp(self): self.models = [A, B, C, D, E, F, G, H, I, J, K] @@ -40,8 +38,7 @@ def test_default(self): [B, D, E, J, K], ] - compare_paths(self, tree, models, expected_paths) - + compare_paths(self, tree, expected_paths) def test_required(self): "D from C rather than B (default)" @@ -55,8 +52,7 @@ def test_required(self): tree = ModelTree(A, **kwargs) - self.assertEqual(tree._required_joins, {D: C}) - self.assertEqual(tree._required_join_fields, {}) + self.assertEqual(tree._required_joins, {(C, D): None}) self.assertTrue(tree._join_allowed(C, D)) self.assertFalse(tree._join_allowed(B, D)) @@ -75,22 +71,27 @@ def test_required(self): [C, D, E, J, K], ] - compare_paths(self, tree, models, expected_paths) + compare_paths(self, tree, expected_paths) def test_excluded(self): "Prevent D from B (go through C)" kwargs = { - 'excluded_routes': [{ - 'target': 'tests.D', - 'source': 'tests.B' - }], + 'excluded_routes': [ + { + 'target': 'tests.D', + 'source': 'tests.B' + }, + # { + # 'target': 'tests.D', + # 'source': 'tests.F', + # } + ], } tree = ModelTree(A, **kwargs) - self.assertEqual(tree._excluded_joins, {D: B}) - self.assertEqual(tree._excluded_join_fields, {}) + self.assertEqual(tree._excluded_joins, {(B, D): None}) self.assertTrue(tree._join_allowed(C, D)) self.assertFalse(tree._join_allowed(B, D)) @@ -109,8 +110,7 @@ def test_excluded(self): [C, D, E, J, K], ] - compare_paths(self, tree, models, expected_paths) - + compare_paths(self, tree, expected_paths) def test_required_long(self): "G from H rather than D or B." @@ -124,8 +124,7 @@ def test_required_long(self): tree = ModelTree(A, **kwargs) - self.assertEqual(tree._required_joins, {G: H}) - self.assertEqual(tree._required_join_fields, {}) + self.assertEqual(tree._required_joins, {(H, G): None}) self.assertTrue(tree._join_allowed(H, G)) self.assertFalse(tree._join_allowed(B, G)) @@ -145,8 +144,7 @@ def test_required_long(self): [B, D, E, J, K], ] - compare_paths(self, tree, models, expected_paths) - + compare_paths(self, tree, expected_paths) def test_required_excluded_combo_long(self): "G from H (rather than D or B), not F from D, not D from B" @@ -167,11 +165,9 @@ def test_required_excluded_combo_long(self): tree = ModelTree(A, **kwargs) - self.assertEqual(tree._required_joins, {G: H}) - self.assertEqual(tree._required_join_fields, {}) + self.assertEqual(tree._required_joins, {(H, G): None}) - self.assertEqual(tree._excluded_joins, {D: B, F: D}) - self.assertEqual(tree._excluded_join_fields, {}) + self.assertEqual(tree._excluded_joins, {(B, D): None, (D, F): None}) self.assertTrue(tree._join_allowed(C, D)) self.assertFalse(tree._join_allowed(B, D)) @@ -195,7 +191,7 @@ def test_required_excluded_combo_long(self): [C, D, E, J, K], ] - compare_paths(self, tree, models, expected_paths) + compare_paths(self, tree, expected_paths) class FieldRouterTestCase(TestCase): @@ -216,11 +212,11 @@ def test_default(self): [(B, 'b_set'), (G, 'g_set'), (H, 'h_set')], [(B, 'b_set'), (G, 'g_set'), (H, 'h_set'), (I, 'i_set')], [(B, 'b_set'), (D, 'd_set'), (E, 'e_set'), (J, 'j_set')], - [(B, 'b_set'), (D, 'd_set'), (E, 'e_set'), (J, 'j_set'), (K, 'k_set')], + [(B, 'b_set'), (D, 'd_set'), (E, 'e_set'), (J, 'j_set'), + (K, 'k_set')], ] - compare_paths_with_accessor(self, tree, models, expected_paths) - + compare_paths_with_accessor(self, tree, expected_paths) def test_required_field(self): kwargs = { @@ -244,7 +240,66 @@ def test_required_field(self): [(B, 'b_set'), (G, 'g_set'), (H, 'h_set')], [(B, 'b_set'), (G, 'g_set'), (H, 'h_set'), (I, 'i_set')], [(B, 'b_set'), (D, 'd_set'), (E, 'e1_set'), (J, 'j_set')], - [(B, 'b_set'), (D, 'd_set'), (E, 'e1_set'), (J, 'j_set'), (K, 'k_set')], + [(B, 'b_set'), (D, 'd_set'), (E, 'e1_set'), (J, 'j_set'), + (K, 'k_set')], ] - compare_paths_with_accessor(self, tree, models, expected_paths) + compare_paths_with_accessor(self, tree, expected_paths) + + def test_excluded_overlapping(self): + "Prevent D from B and D from F (go through C)" + + kwargs = { + 'excluded_routes': [ + { + 'target': 'tests.D', + 'source': 'tests.B' + }, + { + 'target': 'tests.D', + 'source': 'tests.F', + } + ], + } + + tree = ModelTree(A, **kwargs) + + self.assertEqual(tree._excluded_joins, {(B, D): None, + (F, D): None}) + + self.assertTrue(tree._join_allowed(C, D)) + self.assertFalse(tree._join_allowed(B, D)) + self.assertFalse(tree._join_allowed(F, D)) + + expected_paths = [ + [], + [B], + [C], + [C, D], + [C, D, E], + [C, D, F], + [B, G], + [B, G, H], + [B, G, H, I], + [C, D, E, J], + [C, D, E, J, K], + ] + + compare_paths(self, tree, expected_paths) + + def test_required_collision(self): + """Prevent two rules requiring the same target, e.g. + C->D and B->D""" + + kwargs = { + 'required_routes': [{ + 'target': 'tests.D', + 'source': 'tests.C' + }, { + 'target': 'tests.D', + 'source': 'tests.B' + }], + } + + with self.assertRaises(ValueError): + ModelTree(A, **kwargs) diff --git a/tests/models.py b/tests/models.py index f08b4d4..08402ca 100644 --- a/tests/models.py +++ b/tests/models.py @@ -41,6 +41,19 @@ class Meeting(models.Model): # Router Test Models.. no relation to the above models +# The raw model tree looks like this: +# A +# / \ +# C B +# | / \ +# D G +# / \ | +# E F | +# \ / \ | +# J H +# | | +# K I + class A(models.Model): pass