Permalink
Browse files

Improve handling of DoesNotExist and MultipleObjectsReturned cases. M…

…oved shared mock components for manager/queryset into SharedMock
  • Loading branch information...
1 parent 2fe2d6f commit 74c474a221e81bffe6d529dae2e06c6543d1c091 @dcramer committed Feb 21, 2012
Showing with 73 additions and 99 deletions.
  1. +4 −48 mock_django/managers.py
  2. +6 −50 mock_django/query.py
  3. +47 −0 mock_django/shared.py
  4. +1 −1 setup.py
  5. +15 −0 tests/mock_django/managers/tests.py
View
@@ -8,56 +8,12 @@
import mock
from .query import QuerySetMock
+from .shared import SharedMock
__all__ = ('ManagerMock',)
-class _ManagerMock(mock.MagicMock):
- def __init__(self, *args, **kwargs):
- super(_ManagerMock, self).__init__(*args, **kwargs)
- parent = mock.MagicMock()
- parent.child = self
- self.__parent = parent
-
- def _get_child_mock(self, **kwargs):
- name = kwargs.get('name', '')
- if name[:2] == name[-2:] == '__':
- return super(_ManagerMock, self)._get_child_mock(**kwargs)
- return self
-
- def __getattr__(self, name):
- result = super(_ManagerMock, self).__getattr__(name)
- if result is self:
- result._mock_name = result._mock_new_name = name
- return result
-
- def assert_chain_calls(self, *calls):
- """
- Asserts that a chained method was called (parents in the chain do not
- matter, nor are they tracked).
-
- >>> obj.assert_chain_calls(call.filter(foo='bar'))
- >>> obj.assert_chain_calls(call.select_related('baz'))
- """
- all_calls = self.__parent.mock_calls[:]
-
- not_found = []
- for kall in calls:
- try:
- all_calls.remove(kall)
- except ValueError:
- not_found.append(kall)
- if not_found:
- if self.__parent.mock_calls:
- message = '%r not all found in call list, %d other(s) were:\n%r' % (not_found, len(self.__parent.mock_calls),
- self.__parent.mock_calls)
- else:
- message = 'no calls were found'
-
- raise AssertionError(message)
-
-
def ManagerMock(manager, *return_value):
"""
Set the results to two items:
@@ -70,9 +26,9 @@ def ManagerMock(manager, *return_value):
>>> objects = ManagerMock(Post.objects, Exception())
"""
- def make_get_query_set(self, actual_model):
+ def make_get_query_set(self, model):
def _get(*a, **k):
- return QuerySetMock(actual_model, *return_value)
+ return QuerySetMock(model, *return_value)
return _get
actual_model = getattr(manager, 'model', None)
@@ -81,7 +37,7 @@ def _get(*a, **k):
else:
model = mock.MagicMock()
- m = _ManagerMock()
+ m = SharedMock()
m.model = model
m.get_query_set = make_get_query_set(m, actual_model)
m.get = m.get_query_set().get
View
@@ -7,55 +7,11 @@
"""
import mock
+from .shared import SharedMock
__all__ = ('QuerySetMock',)
-class _QuerySetMock(mock.MagicMock):
- def __init__(self, *args, **kwargs):
- super(_QuerySetMock, self).__init__(*args, **kwargs)
- parent = mock.MagicMock()
- parent.child = self
- self.__parent = parent
-
- def _get_child_mock(self, **kwargs):
- name = kwargs.get('name', '')
- if name[:2] == name[-2:] == '__':
- return super(_QuerySetMock, self)._get_child_mock(**kwargs)
- return self
-
- def __getattr__(self, name):
- result = super(_QuerySetMock, self).__getattr__(name)
- if result is self:
- result._mock_name = result._mock_new_name = name
- return result
-
- def assert_chain_calls(self, *calls):
- """
- Asserts that a chained method was called (parents in the chain do not
- matter, nor are they tracked).
-
- >>> obj.assert_chain_calls(call.filter(foo='bar'))
- >>> obj.assert_chain_calls(call.select_related('baz'))
- """
- all_calls = self.__parent.mock_calls[:]
-
- not_found = []
- for kall in calls:
- try:
- all_calls.remove(kall)
- except ValueError:
- not_found.append(kall)
- if not_found:
- if self.__parent.mock_calls:
- message = '%r not all found in call list, %d other(s) were:\n%r' % (not_found, len(self.__parent.mock_calls),
- self.__parent.mock_calls)
- else:
- message = 'no calls were found'
-
- raise AssertionError(message)
-
-
def QuerySetMock(model, *return_value):
"""
Set the results to two items:
@@ -68,15 +24,15 @@ def QuerySetMock(model, *return_value):
>>> objects = QuerySetMock(Post, Exception())
"""
- def make_get(self):
+ def make_get(self, model):
def _get(*a, **k):
results = list(self)
if len(results) > 1:
- raise self.model.MultipleObjectsReturned
+ raise model.MultipleObjectsReturned
try:
return results[0]
except IndexError:
- raise self.model.DoesNotExist
+ raise model.DoesNotExist
return _get
def make_getitem(self):
@@ -106,12 +62,12 @@ def _iterator(*a, **k):
else:
model = mock.MagicMock()
- m = _QuerySetMock()
+ m = SharedMock()
m.__start = None
m.__stop = None
m.__iter__.side_effect = lambda: iter(m.iterator())
m.__getitem__.side_effect = make_getitem(m)
m.model = model
- m.get = make_get(m)
+ m.get = make_get(m, actual_model)
m.iterator.side_effect = make_iterator(m)
return m
View
@@ -0,0 +1,47 @@
+import mock
+
+
+class SharedMock(mock.MagicMock):
+ def __init__(self, *args, **kwargs):
+ super(SharedMock, self).__init__(*args, **kwargs)
+ parent = mock.MagicMock()
+ parent.child = self
+ self.__parent = parent
+
+ def _get_child_mock(self, **kwargs):
+ name = kwargs.get('name', '')
+ if name[:2] == name[-2:] == '__':
+ return super(SharedMock, self)._get_child_mock(**kwargs)
+ return self
+
+ def __getattr__(self, name):
+ result = super(SharedMock, self).__getattr__(name)
+ if result is self:
+ result._mock_name = result._mock_new_name = name
+ return result
+
+ def assert_chain_calls(self, *calls):
+ """
+ Asserts that a chained method was called (parents in the chain do not
+ matter, nor are they tracked).
+
+ >>> obj.assert_chain_calls(call.filter(foo='bar'))
+ >>> obj.assert_chain_calls(call.select_related('baz'))
+ """
+ all_calls = self.__parent.mock_calls[:]
+
+ not_found = []
+ for kall in calls:
+ try:
+ all_calls.remove(kall)
+ except ValueError:
+ not_found.append(kall)
+ if not_found:
+ if self.__parent.mock_calls:
+ message = '%r not all found in call list, %d other(s) were:\n%r' % (not_found, len(self.__parent.mock_calls),
+ self.__parent.mock_calls)
+ else:
+ message = 'no calls were found'
+
+ raise AssertionError(message)
+
View
@@ -2,7 +2,7 @@
setup(
name='mock-django',
- version='0.3.0',
+ version='0.4.0',
description='',
license='Apache License 2.0',
author='David Cramer',
@@ -3,10 +3,19 @@
from unittest2 import TestCase
+class Model(object):
+ class DoesNotExist(Exception):
+ pass
+
+ class MultipleObjectsReturned(Exception):
+ pass
+
+
def make_manager():
manager = mock.MagicMock(spec=(
'all', 'filter', 'order_by',
))
+ manager.model = Model
return manager
@@ -69,3 +78,9 @@ def test_getitem_get(self):
manager = make_manager()
inst = ManagerMock(manager, 'foo')
self.assertEquals(inst[0:1].get(), 'foo')
+
+ def test_get_raises_doesnotexist_with_queryset(self):
+ manager = make_manager()
+ inst = ManagerMock(manager)
+ queryset = inst.using('default.slave')[0:1]
+ self.assertRaises(manager.model.DoesNotExist, queryset.get)

0 comments on commit 74c474a

Please sign in to comment.