From 31745dcf9ce3f38adf0fcd47ee7d7bcb151af054 Mon Sep 17 00:00:00 2001 From: Charlie Denton Date: Thu, 18 Jan 2018 23:17:33 +0000 Subject: [PATCH] Add shortcut to annotate and filter Routes by level --- CHANGELOG.md | 1 + conman/routes/expressions.py | 24 ++++++++++++++++++ conman/routes/managers.py | 14 +++++++++++ tests/routes/test_expressions.py | 43 ++++++++++++++++++++++++++++++++ tests/routes/test_managers.py | 21 ++++++++++++++++ 5 files changed, 103 insertions(+) create mode 100644 conman/routes/expressions.py create mode 100644 tests/routes/test_expressions.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 87e64fe..945eca5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ ### Added +- Added `Route.objects.with_level()`. - Added `Route.get_subclasses()`. - Added `TemplateHandler`. A simpler handler that requires only a template. This is the new default for `Route.handler_class`. diff --git a/conman/routes/expressions.py b/conman/routes/expressions.py new file mode 100644 index 0000000..56aeb4d --- /dev/null +++ b/conman/routes/expressions.py @@ -0,0 +1,24 @@ +from django.db.models.expressions import Func + + +class CharCount(Func): + """ + Count the occurrences of a char within a field. + + Works by finding the difference in length between the whole string, and the + string with the char removed. + """ + template = "CHAR_LENGTH(%(field)s) - CHAR_LENGTH(REPLACE(%(field)s, '%(char)s', ''))" + + def __init__(self, field, *, char, **extra): + """ + Add some validation to the invocation. + + "Char" must always: + + - be passed as a keyword argument + - be exactly one character. + """ + if len(char) != 1: + raise ValueError('CharCount must count exactly one char.') + super().__init__(field, char=char, **extra) diff --git a/conman/routes/managers.py b/conman/routes/managers.py index 6fb5cd4..e13fa83 100644 --- a/conman/routes/managers.py +++ b/conman/routes/managers.py @@ -4,6 +4,7 @@ from polymorphic.managers import PolymorphicManager from .exceptions import InvalidURL +from .expressions import CharCount from .utils import split_path @@ -72,3 +73,16 @@ def move_branch(self, old_url, new_url): Value(new_url), Substr('url', len(old_url) + 1), # 1 indexed )) + + def with_level(self, level=None): + """ + Annotate the queryset with the level of each item. + + The level reflects the number of forward slashes in the path. + + If "level" is passed in, the queryset will be filtered by the level. + """ + qs = self.annotate(level=CharCount('url', char='/')) + if level is None: + return qs + return qs.filter(level=level) diff --git a/tests/routes/test_expressions.py b/tests/routes/test_expressions.py new file mode 100644 index 0000000..341b4ab --- /dev/null +++ b/tests/routes/test_expressions.py @@ -0,0 +1,43 @@ +from django.test import TestCase + +from conman.routes.expressions import CharCount +from conman.routes.models import Route + +from .factories import RouteFactory + + +class TestCharCount(TestCase): + """Tests for CharCount.""" + def test_query(self): + """Match the exact value of the generated query.""" + # The "only" here is handy to keep the query as short as possible. + qs = Route.objects.only('id').annotate(level=CharCount('url', char='/')) + # Excuse the line wrapping here -- I wasn't sure of a nice way to do it. + # I decided it was better to just keep it simple. + expected = ( + 'SELECT "routes_route"."id", ' + + 'CHAR_LENGTH("routes_route"."url") - ' + + 'CHAR_LENGTH(REPLACE("routes_route"."url", \'/\', \'\')) AS "level" ' + + 'FROM "routes_route"' + ) + self.assertEqual(str(qs.query), expected) + + def test_annotation(self): + """Test the expression can be used for annotation.""" + RouteFactory.create(url='/sixth/level/path/including/root/') + route = Route.objects.annotate(level=CharCount('url', char='/')).get() + self.assertEqual(route.level, 6) # The number of "/" in the path. + + def test_calling_format(self): + """Ensure the 'char' argument is always a keyword-arg.""" + with self.assertRaises(TypeError): + CharCount('url', 'unacceptable') + + def test_char_length(self): + """Ensure 'char' length is always 1.""" + msg = 'CharCount must count exactly one char.' + with self.assertRaisesMessage(ValueError, msg): + CharCount('url', char='no') + + with self.assertRaisesMessage(ValueError, msg): + CharCount('url', char='') diff --git a/tests/routes/test_managers.py b/tests/routes/test_managers.py index 7dd1d42..125096e 100644 --- a/tests/routes/test_managers.py +++ b/tests/routes/test_managers.py @@ -198,3 +198,24 @@ def test_descendant_destination_occupied(self): self.assertEqual(route.url, original_url) child.refresh_from_db() self.assertEqual(child.url, original_url + 'child/') + + +class RouteManagerWithPathTest(TestCase): + """Test Route.objects.with_level.""" + def test_no_level_passed(self): + """No level passed, so items are annotated, but no filter is applied.""" + RouteFactory.create(url='/') + RouteFactory.create(url='/branch/') + RouteFactory.create(url='/branch/leaf/') + result = Route.objects.with_level().order_by('level') + self.assertEqual(result[0].level, 1) + self.assertEqual(result[1].level, 2) + self.assertEqual(result[2].level, 3) + + def test_level_passed(self): + """When passing a level, the filter is automatically applied.""" + RouteFactory.create(url='/') # Not in QS. + branch = RouteFactory.create(url='/branch/') + RouteFactory.create(url='/branch/leaf/') # Not in QS. + result = Route.objects.with_level(2) + self.assertCountEqual(result, [branch])