Skip to content
This repository has been archived by the owner on Jan 18, 2020. It is now read-only.

Commit

Permalink
Merge 8d8ccb5 into 2982b10
Browse files Browse the repository at this point in the history
  • Loading branch information
murphyke committed Nov 5, 2014
2 parents 2982b10 + 8d8ccb5 commit f292ae8
Show file tree
Hide file tree
Showing 3 changed files with 159 additions and 71 deletions.
100 changes: 60 additions & 40 deletions modeltree/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,8 +300,15 @@ class ModelTree(object):
sometimes it is necessary to exclude join paths. For example if there
are three possible paths and one should never occur.
A modeltree config takes the `required_routes` and `excluded_routes`
which is a list of routes in the above format.
A modeltree config can have `required_routes` and `excluded_routes`
entries, which are lists of routes in the above format.
A required route is defined as follows: a join to the specified target
model is only allowed from the specified source model. A single
source model can participate in multiple required routes.
An excluded route is more obvious: joining from the
specified source model to the specified target model is not allowed.
"""
def __init__(self, model=None, **kwargs):
Expand Down Expand Up @@ -337,13 +344,13 @@ def __init__(self, model=None, **kwargs):
self.excluded_models = [self.get_model(label, local=False)
for label in excluded_models]

# Build the routes are allowed/preferred
self._required_joins, self._required_join_fields = \
self._build_routes(required_routes)
# Build the routes that are allowed/preferred
self._required_joins = self._build_routes(
required_routes,
allow_redundant_targets=False)

# Build the routes that are excluded
self._excluded_joins, self._excluded_join_fields = \
self._build_routes(excluded_routes)
self._excluded_joins = self._build_routes(excluded_routes)

# cache each node relative their models
self._nodes = {}
Expand Down Expand Up @@ -467,11 +474,19 @@ def get_field(self, name, model=None):
model = self.root_model
return model._meta.get_field_by_name(name)[0]

def _build_routes(self, routes):
"Routes provide a means of specifying JOINs between two tables."
def _build_routes(self, routes, allow_redundant_targets=True):
"""Routes provide a means of specifying JOINs between two tables.
routes - a collection of dicts defining source->target mappings
with optional `field` specifier and `symmetrical` attribute.
allow_redundant_targets - whether two routes in this collection
are allowed to have the same target - this should NOT
be allowed for required routes.
"""
routes = routes or ()
joins = {}
join_fields = {}
targets_seen = set()

for route in routes:
if isinstance(route, dict):
Expand Down Expand Up @@ -507,20 +522,23 @@ def _build_routes(self, routes):
if isinstance(field, RelatedObject):
field = field.field

# the `joins` hash defines pairs which are explicitly joined
# via the specified field
# if no field is defined, then the join field is implied or
# does not matter. the route is reduced to a straight lookup
joins[target] = source
if symmetrical:
joins[source] = target
if not allow_redundant_targets:
if target in targets_seen:
tpl = ('Model {0} cannot be the target of more than one '
'route in this list')
raise ValueError(tpl.format(target_label))
else:
targets_seen.add(target)

if field is not None:
join_fields[(source, target)] = field
if symmetrical:
join_fields[(target, source)] = field
# The `joins` hash defines pairs which are explicitly joined
# via the specified field.
# If no field is defined, then the join field is implied or
# does not matter; the route is reduced to a straight lookup.
joins[(source, target)] = field
if symmetrical:
joins[(target, source)] = field

return joins, join_fields
return joins

def _join_allowed(self, source, target, field=None):
"""Checks if the join between `source` and `target` via `field`
Expand All @@ -540,29 +558,31 @@ def _join_allowed(self, source, target, field=None):
if target == self.root_model:
return False

# Check if the join is excluded via a specific field
if field and join in self._excluded_join_fields:
_field = self._excluded_join_fields[join]
if _field == field:
# Apply excluded joins if any
if join in self._excluded_joins:
_field = self._excluded_joins[join]
if _field:
if _field == field:
return False
else:
return False

# Model level..
elif source == self._excluded_joins.get(target):
return False
# Definition of required join: for the specified target,
# the only join allowed is from the specified source.
# (There can only be one 'required' rule for a given target).

# Check if the join is allowed
if target in self._required_joins:
_source = self._required_joins[target]
if _source != source:
return False

# If a field is supplied, check to see if the field is allowed
# for this join.
if field:
_field = self._required_join_fields.get(join)
if _field and _field != field:
# Check if the join is allowed by a required rule
for (_source, _target), _field in self._required_joins.items():
if _target == target:
if _source != source:
return False

# If a field is supplied, check to see if the field is allowed
# for this join.
if field:
if _field and _field != field:
return False

return True

def _filter_one2one(self, field):
Expand Down
117 changes: 86 additions & 31 deletions tests/cases/core/tests/test_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,18 @@
__all__ = ('RouterTestCase', 'FieldRouterTestCase')



def compare_paths(self, tree, models, expected_paths):
def compare_paths(self, tree, expected_paths):
for i, model in enumerate(self.models):
path = [n.model for n in tree._node_path(model)]
self.assertEqual(path, expected_paths[i])


def compare_paths_with_accessor(self, tree, models, expected_paths):
def compare_paths_with_accessor(self, tree, expected_paths):
for i, model in enumerate(self.models):
path = [(n.model, n.accessor_name) for n in tree._node_path(model)]
self.assertEqual(path, expected_paths[i])



class RouterTestCase(TestCase):
def setUp(self):
self.models = [A, B, C, D, E, F, G, H, I, J, K]
Expand All @@ -40,8 +38,7 @@ def test_default(self):
[B, D, E, J, K],
]

compare_paths(self, tree, models, expected_paths)

compare_paths(self, tree, expected_paths)

def test_required(self):
"D from C rather than B (default)"
Expand All @@ -55,8 +52,7 @@ def test_required(self):

tree = ModelTree(A, **kwargs)

self.assertEqual(tree._required_joins, {D: C})
self.assertEqual(tree._required_join_fields, {})
self.assertEqual(tree._required_joins, {(C, D): None})

self.assertTrue(tree._join_allowed(C, D))
self.assertFalse(tree._join_allowed(B, D))
Expand All @@ -75,22 +71,27 @@ def test_required(self):
[C, D, E, J, K],
]

compare_paths(self, tree, models, expected_paths)
compare_paths(self, tree, expected_paths)

def test_excluded(self):
"Prevent D from B (go through C)"

kwargs = {
'excluded_routes': [{
'target': 'tests.D',
'source': 'tests.B'
}],
'excluded_routes': [
{
'target': 'tests.D',
'source': 'tests.B'
},
# {
# 'target': 'tests.D',
# 'source': 'tests.F',
# }
],
}

tree = ModelTree(A, **kwargs)

self.assertEqual(tree._excluded_joins, {D: B})
self.assertEqual(tree._excluded_join_fields, {})
self.assertEqual(tree._excluded_joins, {(B, D): None})

self.assertTrue(tree._join_allowed(C, D))
self.assertFalse(tree._join_allowed(B, D))
Expand All @@ -109,8 +110,7 @@ def test_excluded(self):
[C, D, E, J, K],
]

compare_paths(self, tree, models, expected_paths)

compare_paths(self, tree, expected_paths)

def test_required_long(self):
"G from H rather than D or B."
Expand All @@ -124,8 +124,7 @@ def test_required_long(self):

tree = ModelTree(A, **kwargs)

self.assertEqual(tree._required_joins, {G: H})
self.assertEqual(tree._required_join_fields, {})
self.assertEqual(tree._required_joins, {(H, G): None})

self.assertTrue(tree._join_allowed(H, G))
self.assertFalse(tree._join_allowed(B, G))
Expand All @@ -145,8 +144,7 @@ def test_required_long(self):
[B, D, E, J, K],
]

compare_paths(self, tree, models, expected_paths)

compare_paths(self, tree, expected_paths)

def test_required_excluded_combo_long(self):
"G from H (rather than D or B), not F from D, not D from B"
Expand All @@ -167,11 +165,9 @@ def test_required_excluded_combo_long(self):

tree = ModelTree(A, **kwargs)

self.assertEqual(tree._required_joins, {G: H})
self.assertEqual(tree._required_join_fields, {})
self.assertEqual(tree._required_joins, {(H, G): None})

self.assertEqual(tree._excluded_joins, {D: B, F: D})
self.assertEqual(tree._excluded_join_fields, {})
self.assertEqual(tree._excluded_joins, {(B, D): None, (D, F): None})

self.assertTrue(tree._join_allowed(C, D))
self.assertFalse(tree._join_allowed(B, D))
Expand All @@ -195,7 +191,7 @@ def test_required_excluded_combo_long(self):
[C, D, E, J, K],
]

compare_paths(self, tree, models, expected_paths)
compare_paths(self, tree, expected_paths)


class FieldRouterTestCase(TestCase):
Expand All @@ -216,11 +212,11 @@ def test_default(self):
[(B, 'b_set'), (G, 'g_set'), (H, 'h_set')],
[(B, 'b_set'), (G, 'g_set'), (H, 'h_set'), (I, 'i_set')],
[(B, 'b_set'), (D, 'd_set'), (E, 'e_set'), (J, 'j_set')],
[(B, 'b_set'), (D, 'd_set'), (E, 'e_set'), (J, 'j_set'), (K, 'k_set')],
[(B, 'b_set'), (D, 'd_set'), (E, 'e_set'), (J, 'j_set'),
(K, 'k_set')],
]

compare_paths_with_accessor(self, tree, models, expected_paths)

compare_paths_with_accessor(self, tree, expected_paths)

def test_required_field(self):
kwargs = {
Expand All @@ -244,7 +240,66 @@ def test_required_field(self):
[(B, 'b_set'), (G, 'g_set'), (H, 'h_set')],
[(B, 'b_set'), (G, 'g_set'), (H, 'h_set'), (I, 'i_set')],
[(B, 'b_set'), (D, 'd_set'), (E, 'e1_set'), (J, 'j_set')],
[(B, 'b_set'), (D, 'd_set'), (E, 'e1_set'), (J, 'j_set'), (K, 'k_set')],
[(B, 'b_set'), (D, 'd_set'), (E, 'e1_set'), (J, 'j_set'),
(K, 'k_set')],
]

compare_paths_with_accessor(self, tree, models, expected_paths)
compare_paths_with_accessor(self, tree, expected_paths)

def test_excluded_overlapping(self):
"Prevent D from B and D from F (go through C)"

kwargs = {
'excluded_routes': [
{
'target': 'tests.D',
'source': 'tests.B'
},
{
'target': 'tests.D',
'source': 'tests.F',
}
],
}

tree = ModelTree(A, **kwargs)

self.assertEqual(tree._excluded_joins, {(B, D): None,
(F, D): None})

self.assertTrue(tree._join_allowed(C, D))
self.assertFalse(tree._join_allowed(B, D))
self.assertFalse(tree._join_allowed(F, D))

expected_paths = [
[],
[B],
[C],
[C, D],
[C, D, E],
[C, D, F],
[B, G],
[B, G, H],
[B, G, H, I],
[C, D, E, J],
[C, D, E, J, K],
]

compare_paths(self, tree, expected_paths)

def test_required_collision(self):
"""Prevent two rules requiring the same target, e.g.
C->D and B->D"""

kwargs = {
'required_routes': [{
'target': 'tests.D',
'source': 'tests.C'
}, {
'target': 'tests.D',
'source': 'tests.B'
}],
}

with self.assertRaises(ValueError):
ModelTree(A, **kwargs)
13 changes: 13 additions & 0 deletions tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,19 @@ class Meeting(models.Model):


# Router Test Models.. no relation to the above models
# The raw model tree looks like this:
# A
# / \
# C B
# | / \
# D G
# / \ |
# E F |
# \ / \ |
# J H
# | |
# K I

class A(models.Model):
pass

Expand Down

0 comments on commit f292ae8

Please sign in to comment.