Skip to content
This repository has been archived by the owner on Sep 12, 2021. It is now read-only.

Commit

Permalink
support chaining
Browse files Browse the repository at this point in the history
  • Loading branch information
jrydberg committed Apr 25, 2012
1 parent a8e6d3e commit 4adbc11
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 7 deletions.
26 changes: 20 additions & 6 deletions mockito/invocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
# coding: utf-8

import matchers
from verification import Times


__copyright__ = "Copyright 2008-2010, Mockito Contributors"
__license__ = "MIT"
Expand All @@ -21,7 +23,9 @@ def __init__(self, mock, method_name):
self.named_params = {}
self.answers = []
self.strict = mock.strict

from mockito import mock
self.chain = mock(chainable=True)

def _remember_params(self, params, named_params):
self.params = params
self.named_params = named_params
Expand Down Expand Up @@ -64,13 +68,13 @@ class RememberedInvocation(Invocation):
def __call__(self, *params, **named_params):
self._remember_params(params, named_params)
self.mock.remember(self)

for matching_invocation in self.mock.stubbed_invocations:
if matching_invocation.matches(self):
return matching_invocation.answer_first()

return None
return self.chain if self.mock.chainable else None

class RememberedProxyInvocation(Invocation):
'''Remeber params and proxy to method of original object.
Expand Down Expand Up @@ -99,13 +103,17 @@ def __call__(self, *params, **named_params):

for invocation in matched_invocations:
invocation.verified = True
if self.mock.chainable:
invocation.chain.verification = Times(1)
return invocation.chain


class StubbedInvocation(MatchingInvocation):
def __init__(self, *params):
super(StubbedInvocation, self).__init__(*params)
if self.mock.strict:
self.ensure_mocked_object_has_method(self.method_name)

def ensure_mocked_object_has_method(self, method_name):
if not self.mock.has_method(method_name):
raise InvocationError("You tried to stub a method '%s' the object (%s) doesn't have."
Expand All @@ -120,7 +128,7 @@ def stub_with(self, answer):
self.answers.append(answer)
self.mock.stub(self.method_name)
self.mock.finish_stubbing(self)

class AnswerSelector(object):
def __init__(self, invocation):
self.invocation = invocation
Expand All @@ -136,6 +144,12 @@ def thenRaise(self, *exceptions):
self.__then(Raise(exception))
return self

def __getattr__(self, name):
return_value = self.invocation.chain
self.__then(Return(return_value))
return_value.expect_stubbing()
return getattr(return_value, name)

def __then(self, answer):
if not self.answer:
self.answer = CompositeAnswer(answer)
Expand Down
3 changes: 2 additions & 1 deletion mockito/mocking.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def fromMethod(meth):
import inspect

class mock(TestDouble):
def __init__(self, mocked_obj=None, strict=True):
def __init__(self, mocked_obj=None, strict=True, chainable=False):
self.invocations = []
self.stubbed_invocations = []
self.original_methods = []
Expand All @@ -95,6 +95,7 @@ def __init__(self, mocked_obj=None, strict=True):
self.zi = isinstance(mocked_obj, zi.interface.InterfaceClass)
else:
self.zi = False
self.chainable = chainable

mock_registry.register(self)

Expand Down
11 changes: 11 additions & 0 deletions mockito_test/stubbing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,17 @@ def bar_name(a):
else:
self.assertTrue(False, "StubbingError not raised")

def testChainableStubs(self):
person = mock(chainable=True)
person.needs().help(10)
verify(person).needs().help(10)

that = mock(chainable=True)
when(that).a().b(10).thenReturn(20)
x = that.a().b(10)
self.assertEquals(x, 20)


# TODO: verify after stubbing and vice versa

if __name__ == '__main__':
Expand Down

0 comments on commit 4adbc11

Please sign in to comment.