Skip to content
This repository has been archived by the owner on Apr 16, 2023. It is now read-only.

Commit

Permalink
Add shortcut to annotate and filter Routes by level
Browse files Browse the repository at this point in the history
  • Loading branch information
meshy committed Jan 18, 2018
1 parent 3a60756 commit 31745dc
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

### Added

- Added `Route.objects.with_level()`.
- Added `Route.get_subclasses()`.
- Added `TemplateHandler`. A simpler handler that requires only a template.
This is the new default for `Route.handler_class`.
Expand Down
24 changes: 24 additions & 0 deletions conman/routes/expressions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from django.db.models.expressions import Func


class CharCount(Func):
"""
Count the occurrences of a char within a field.
Works by finding the difference in length between the whole string, and the
string with the char removed.
"""
template = "CHAR_LENGTH(%(field)s) - CHAR_LENGTH(REPLACE(%(field)s, '%(char)s', ''))"

def __init__(self, field, *, char, **extra):
"""
Add some validation to the invocation.
"Char" must always:
- be passed as a keyword argument
- be exactly one character.
"""
if len(char) != 1:
raise ValueError('CharCount must count exactly one char.')
super().__init__(field, char=char, **extra)
14 changes: 14 additions & 0 deletions conman/routes/managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from polymorphic.managers import PolymorphicManager

from .exceptions import InvalidURL
from .expressions import CharCount
from .utils import split_path


Expand Down Expand Up @@ -72,3 +73,16 @@ def move_branch(self, old_url, new_url):
Value(new_url),
Substr('url', len(old_url) + 1), # 1 indexed
))

def with_level(self, level=None):
"""
Annotate the queryset with the level of each item.
The level reflects the number of forward slashes in the path.
If "level" is passed in, the queryset will be filtered by the level.
"""
qs = self.annotate(level=CharCount('url', char='/'))
if level is None:
return qs
return qs.filter(level=level)
43 changes: 43 additions & 0 deletions tests/routes/test_expressions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from django.test import TestCase

from conman.routes.expressions import CharCount
from conman.routes.models import Route

from .factories import RouteFactory


class TestCharCount(TestCase):
"""Tests for CharCount."""
def test_query(self):
"""Match the exact value of the generated query."""
# The "only" here is handy to keep the query as short as possible.
qs = Route.objects.only('id').annotate(level=CharCount('url', char='/'))
# Excuse the line wrapping here -- I wasn't sure of a nice way to do it.
# I decided it was better to just keep it simple.
expected = (
'SELECT "routes_route"."id", ' +
'CHAR_LENGTH("routes_route"."url") - ' +
'CHAR_LENGTH(REPLACE("routes_route"."url", \'/\', \'\')) AS "level" ' +
'FROM "routes_route"'
)
self.assertEqual(str(qs.query), expected)

def test_annotation(self):
"""Test the expression can be used for annotation."""
RouteFactory.create(url='/sixth/level/path/including/root/')
route = Route.objects.annotate(level=CharCount('url', char='/')).get()
self.assertEqual(route.level, 6) # The number of "/" in the path.

def test_calling_format(self):
"""Ensure the 'char' argument is always a keyword-arg."""
with self.assertRaises(TypeError):
CharCount('url', 'unacceptable')

def test_char_length(self):
"""Ensure 'char' length is always 1."""
msg = 'CharCount must count exactly one char.'
with self.assertRaisesMessage(ValueError, msg):
CharCount('url', char='no')

with self.assertRaisesMessage(ValueError, msg):
CharCount('url', char='')
21 changes: 21 additions & 0 deletions tests/routes/test_managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,3 +198,24 @@ def test_descendant_destination_occupied(self):
self.assertEqual(route.url, original_url)
child.refresh_from_db()
self.assertEqual(child.url, original_url + 'child/')


class RouteManagerWithPathTest(TestCase):
"""Test Route.objects.with_level."""
def test_no_level_passed(self):
"""No level passed, so items are annotated, but no filter is applied."""
RouteFactory.create(url='/')
RouteFactory.create(url='/branch/')
RouteFactory.create(url='/branch/leaf/')
result = Route.objects.with_level().order_by('level')
self.assertEqual(result[0].level, 1)
self.assertEqual(result[1].level, 2)
self.assertEqual(result[2].level, 3)

def test_level_passed(self):
"""When passing a level, the filter is automatically applied."""
RouteFactory.create(url='/') # Not in QS.
branch = RouteFactory.create(url='/branch/')
RouteFactory.create(url='/branch/leaf/') # Not in QS.
result = Route.objects.with_level(2)
self.assertCountEqual(result, [branch])

0 comments on commit 31745dc

Please sign in to comment.