From 8d8ccb5ae6cb9d1b34b15606206e1fc1d744b539 Mon Sep 17 00:00:00 2001 From: Kevin Murphy Date: Tue, 4 Nov 2014 15:43:54 -0500 Subject: [PATCH] Fix bugs in required and excluded routes It is now an error to specify more than one required route with the same target (ValueError is raised). It is now possible to define multiple excluded routes involving the same target model. Previously, the last defined exclusion route for a target model would silently replace any other such routes. Signed-off-by: Kevin Murphy --- modeltree/tree.py | 100 +++++++++++++--------- tests/cases/core/tests/test_routes.py | 117 +++++++++++++++++++------- tests/models.py | 13 +++ 3 files changed, 159 insertions(+), 71 deletions(-) diff --git a/modeltree/tree.py b/modeltree/tree.py index 06239b4..f8831ff 100644 --- a/modeltree/tree.py +++ b/modeltree/tree.py @@ -300,8 +300,15 @@ class ModelTree(object): sometimes it is necessary to exclude join paths. For example if there are three possible paths and one should never occur. - A modeltree config takes the `required_routes` and `excluded_routes` - which is a list of routes in the above format. + A modeltree config can have `required_routes` and `excluded_routes` + entries, which are lists of routes in the above format. + + A required route is defined as follows: a join to the specified target + model is only allowed from the specified source model. A single + source model can participate in multiple required routes. + + An excluded route is more obvious: joining from the + specified source model to the specified target model is not allowed. """ def __init__(self, model=None, **kwargs): @@ -337,13 +344,13 @@ def __init__(self, model=None, **kwargs): self.excluded_models = [self.get_model(label, local=False) for label in excluded_models] - # Build the routes are allowed/preferred - self._required_joins, self._required_join_fields = \ - self._build_routes(required_routes) + # Build the routes that are allowed/preferred + self._required_joins = self._build_routes( + required_routes, + allow_redundant_targets=False) # Build the routes that are excluded - self._excluded_joins, self._excluded_join_fields = \ - self._build_routes(excluded_routes) + self._excluded_joins = self._build_routes(excluded_routes) # cache each node relative their models self._nodes = {} @@ -467,11 +474,19 @@ def get_field(self, name, model=None): model = self.root_model return model._meta.get_field_by_name(name)[0] - def _build_routes(self, routes): - "Routes provide a means of specifying JOINs between two tables." + def _build_routes(self, routes, allow_redundant_targets=True): + """Routes provide a means of specifying JOINs between two tables. + + routes - a collection of dicts defining source->target mappings + with optional `field` specifier and `symmetrical` attribute. + + allow_redundant_targets - whether two routes in this collection + are allowed to have the same target - this should NOT + be allowed for required routes. + """ routes = routes or () joins = {} - join_fields = {} + targets_seen = set() for route in routes: if isinstance(route, dict): @@ -507,20 +522,23 @@ def _build_routes(self, routes): if isinstance(field, RelatedObject): field = field.field - # the `joins` hash defines pairs which are explicitly joined - # via the specified field - # if no field is defined, then the join field is implied or - # does not matter. the route is reduced to a straight lookup - joins[target] = source - if symmetrical: - joins[source] = target + if not allow_redundant_targets: + if target in targets_seen: + tpl = ('Model {0} cannot be the target of more than one ' + 'route in this list') + raise ValueError(tpl.format(target_label)) + else: + targets_seen.add(target) - if field is not None: - join_fields[(source, target)] = field - if symmetrical: - join_fields[(target, source)] = field + # The `joins` hash defines pairs which are explicitly joined + # via the specified field. + # If no field is defined, then the join field is implied or + # does not matter; the route is reduced to a straight lookup. + joins[(source, target)] = field + if symmetrical: + joins[(target, source)] = field - return joins, join_fields + return joins def _join_allowed(self, source, target, field=None): """Checks if the join between `source` and `target` via `field` @@ -540,29 +558,31 @@ def _join_allowed(self, source, target, field=None): if target == self.root_model: return False - # Check if the join is excluded via a specific field - if field and join in self._excluded_join_fields: - _field = self._excluded_join_fields[join] - if _field == field: + # Apply excluded joins if any + if join in self._excluded_joins: + _field = self._excluded_joins[join] + if _field: + if _field == field: + return False + else: return False - # Model level.. - elif source == self._excluded_joins.get(target): - return False + # Definition of required join: for the specified target, + # the only join allowed is from the specified source. + # (There can only be one 'required' rule for a given target). - # Check if the join is allowed - if target in self._required_joins: - _source = self._required_joins[target] - if _source != source: - return False - - # If a field is supplied, check to see if the field is allowed - # for this join. - if field: - _field = self._required_join_fields.get(join) - if _field and _field != field: + # Check if the join is allowed by a required rule + for (_source, _target), _field in self._required_joins.items(): + if _target == target: + if _source != source: return False + # If a field is supplied, check to see if the field is allowed + # for this join. + if field: + if _field and _field != field: + return False + return True def _filter_one2one(self, field): diff --git a/tests/cases/core/tests/test_routes.py b/tests/cases/core/tests/test_routes.py index 124776c..57d06c9 100644 --- a/tests/cases/core/tests/test_routes.py +++ b/tests/cases/core/tests/test_routes.py @@ -5,20 +5,18 @@ __all__ = ('RouterTestCase', 'FieldRouterTestCase') - -def compare_paths(self, tree, models, expected_paths): +def compare_paths(self, tree, expected_paths): for i, model in enumerate(self.models): path = [n.model for n in tree._node_path(model)] self.assertEqual(path, expected_paths[i]) -def compare_paths_with_accessor(self, tree, models, expected_paths): +def compare_paths_with_accessor(self, tree, expected_paths): for i, model in enumerate(self.models): path = [(n.model, n.accessor_name) for n in tree._node_path(model)] self.assertEqual(path, expected_paths[i]) - class RouterTestCase(TestCase): def setUp(self): self.models = [A, B, C, D, E, F, G, H, I, J, K] @@ -40,8 +38,7 @@ def test_default(self): [B, D, E, J, K], ] - compare_paths(self, tree, models, expected_paths) - + compare_paths(self, tree, expected_paths) def test_required(self): "D from C rather than B (default)" @@ -55,8 +52,7 @@ def test_required(self): tree = ModelTree(A, **kwargs) - self.assertEqual(tree._required_joins, {D: C}) - self.assertEqual(tree._required_join_fields, {}) + self.assertEqual(tree._required_joins, {(C, D): None}) self.assertTrue(tree._join_allowed(C, D)) self.assertFalse(tree._join_allowed(B, D)) @@ -75,22 +71,27 @@ def test_required(self): [C, D, E, J, K], ] - compare_paths(self, tree, models, expected_paths) + compare_paths(self, tree, expected_paths) def test_excluded(self): "Prevent D from B (go through C)" kwargs = { - 'excluded_routes': [{ - 'target': 'tests.D', - 'source': 'tests.B' - }], + 'excluded_routes': [ + { + 'target': 'tests.D', + 'source': 'tests.B' + }, + # { + # 'target': 'tests.D', + # 'source': 'tests.F', + # } + ], } tree = ModelTree(A, **kwargs) - self.assertEqual(tree._excluded_joins, {D: B}) - self.assertEqual(tree._excluded_join_fields, {}) + self.assertEqual(tree._excluded_joins, {(B, D): None}) self.assertTrue(tree._join_allowed(C, D)) self.assertFalse(tree._join_allowed(B, D)) @@ -109,8 +110,7 @@ def test_excluded(self): [C, D, E, J, K], ] - compare_paths(self, tree, models, expected_paths) - + compare_paths(self, tree, expected_paths) def test_required_long(self): "G from H rather than D or B." @@ -124,8 +124,7 @@ def test_required_long(self): tree = ModelTree(A, **kwargs) - self.assertEqual(tree._required_joins, {G: H}) - self.assertEqual(tree._required_join_fields, {}) + self.assertEqual(tree._required_joins, {(H, G): None}) self.assertTrue(tree._join_allowed(H, G)) self.assertFalse(tree._join_allowed(B, G)) @@ -145,8 +144,7 @@ def test_required_long(self): [B, D, E, J, K], ] - compare_paths(self, tree, models, expected_paths) - + compare_paths(self, tree, expected_paths) def test_required_excluded_combo_long(self): "G from H (rather than D or B), not F from D, not D from B" @@ -167,11 +165,9 @@ def test_required_excluded_combo_long(self): tree = ModelTree(A, **kwargs) - self.assertEqual(tree._required_joins, {G: H}) - self.assertEqual(tree._required_join_fields, {}) + self.assertEqual(tree._required_joins, {(H, G): None}) - self.assertEqual(tree._excluded_joins, {D: B, F: D}) - self.assertEqual(tree._excluded_join_fields, {}) + self.assertEqual(tree._excluded_joins, {(B, D): None, (D, F): None}) self.assertTrue(tree._join_allowed(C, D)) self.assertFalse(tree._join_allowed(B, D)) @@ -195,7 +191,7 @@ def test_required_excluded_combo_long(self): [C, D, E, J, K], ] - compare_paths(self, tree, models, expected_paths) + compare_paths(self, tree, expected_paths) class FieldRouterTestCase(TestCase): @@ -216,11 +212,11 @@ def test_default(self): [(B, 'b_set'), (G, 'g_set'), (H, 'h_set')], [(B, 'b_set'), (G, 'g_set'), (H, 'h_set'), (I, 'i_set')], [(B, 'b_set'), (D, 'd_set'), (E, 'e_set'), (J, 'j_set')], - [(B, 'b_set'), (D, 'd_set'), (E, 'e_set'), (J, 'j_set'), (K, 'k_set')], + [(B, 'b_set'), (D, 'd_set'), (E, 'e_set'), (J, 'j_set'), + (K, 'k_set')], ] - compare_paths_with_accessor(self, tree, models, expected_paths) - + compare_paths_with_accessor(self, tree, expected_paths) def test_required_field(self): kwargs = { @@ -244,7 +240,66 @@ def test_required_field(self): [(B, 'b_set'), (G, 'g_set'), (H, 'h_set')], [(B, 'b_set'), (G, 'g_set'), (H, 'h_set'), (I, 'i_set')], [(B, 'b_set'), (D, 'd_set'), (E, 'e1_set'), (J, 'j_set')], - [(B, 'b_set'), (D, 'd_set'), (E, 'e1_set'), (J, 'j_set'), (K, 'k_set')], + [(B, 'b_set'), (D, 'd_set'), (E, 'e1_set'), (J, 'j_set'), + (K, 'k_set')], ] - compare_paths_with_accessor(self, tree, models, expected_paths) + compare_paths_with_accessor(self, tree, expected_paths) + + def test_excluded_overlapping(self): + "Prevent D from B and D from F (go through C)" + + kwargs = { + 'excluded_routes': [ + { + 'target': 'tests.D', + 'source': 'tests.B' + }, + { + 'target': 'tests.D', + 'source': 'tests.F', + } + ], + } + + tree = ModelTree(A, **kwargs) + + self.assertEqual(tree._excluded_joins, {(B, D): None, + (F, D): None}) + + self.assertTrue(tree._join_allowed(C, D)) + self.assertFalse(tree._join_allowed(B, D)) + self.assertFalse(tree._join_allowed(F, D)) + + expected_paths = [ + [], + [B], + [C], + [C, D], + [C, D, E], + [C, D, F], + [B, G], + [B, G, H], + [B, G, H, I], + [C, D, E, J], + [C, D, E, J, K], + ] + + compare_paths(self, tree, expected_paths) + + def test_required_collision(self): + """Prevent two rules requiring the same target, e.g. + C->D and B->D""" + + kwargs = { + 'required_routes': [{ + 'target': 'tests.D', + 'source': 'tests.C' + }, { + 'target': 'tests.D', + 'source': 'tests.B' + }], + } + + with self.assertRaises(ValueError): + ModelTree(A, **kwargs) diff --git a/tests/models.py b/tests/models.py index f08b4d4..08402ca 100644 --- a/tests/models.py +++ b/tests/models.py @@ -41,6 +41,19 @@ class Meeting(models.Model): # Router Test Models.. no relation to the above models +# The raw model tree looks like this: +# A +# / \ +# C B +# | / \ +# D G +# / \ | +# E F | +# \ / \ | +# J H +# | | +# K I + class A(models.Model): pass