From bf30737282793b9290c1f064c85db5769d2d4ccf Mon Sep 17 00:00:00 2001 From: Maksim Beliaev Date: Mon, 4 Apr 2022 22:18:45 +0200 Subject: [PATCH] Add more useful return values (#529) * `add`, `upsert`, `replace` methods return registered response. * `remove` method returns list of removed responses. * Update `add` method on `FirstMatchRegistry` object to do a deepcopy if the same in-memory object is already in the list. --- CHANGES | 2 ++ README.rst | 46 +++++++++++++++++++++++++++++++ responses/__init__.py | 28 +++++++++---------- responses/__init__.pyi | 10 +++---- responses/registries.py | 16 +++++++---- responses/tests/test_responses.py | 44 +++++++++++++++++++++++++++++ 6 files changed, 122 insertions(+), 24 deletions(-) diff --git a/CHANGES b/CHANGES index a7d667be..e9167197 100644 --- a/CHANGES +++ b/CHANGES @@ -3,6 +3,8 @@ * Add `threading.Lock()` to allow `responses` working with `threading` module. * Removed internal `_cookies_from_headers` function +* Now `add`, `upsert`, `replace` methods return registered response. + `remove` method returns list of removed responses. 0.20.0 ------ diff --git a/README.rst b/README.rst index 51c029e7..94b960f6 100644 --- a/README.rst +++ b/README.rst @@ -805,6 +805,42 @@ the ``assert_all_requests_are_fired`` value: Assert Request Call Count ------------------------- +Assert based on ``Response`` object +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Each ``Response`` object has ``call_count`` attribute that could be inspected +to check how many times each request was matched. + +.. code-block:: python + + @responses.activate + def test_call_count_with_matcher(): + + rsp = responses.add( + responses.GET, + "http://www.example.com", + match=(matchers.query_param_matcher({}),), + ) + rsp2 = responses.add( + responses.GET, + "http://www.example.com", + match=(matchers.query_param_matcher({"hello": "world"}),), + status=777, + ) + requests.get("http://www.example.com") + resp1 = requests.get("http://www.example.com") + requests.get("http://www.example.com?hello=world") + resp2 = requests.get("http://www.example.com?hello=world") + + assert resp1.status_code == 200 + assert resp2.status_code == 777 + + assert rsp.call_count == 2 + assert rsp2.call_count == 2 + +Assert based on the exact URL +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + Assert that the request was called exactly n times. .. code-block:: python @@ -824,6 +860,16 @@ Assert that the request was called exactly n times. responses.assert_call_count("http://example.com", 1) assert "Expected URL 'http://example.com' to be called 1 times. Called 2 times." in str(excinfo.value) + @responses.activate + def test_assert_call_count_always_match_qs(): + responses.add(responses.GET, "http://www.example.com") + requests.get("http://www.example.com") + requests.get("http://www.example.com?hello=world") + + # One call on each url, querystring is matched by default + responses.assert_call_count("http://www.example.com", 1) is True + responses.assert_call_count("http://www.example.com?hello=world", 1) is True + Multiple Responses ------------------ diff --git a/responses/__init__.py b/responses/__init__.py index 63b68b58..460d24e5 100644 --- a/responses/__init__.py +++ b/responses/__init__.py @@ -648,8 +648,7 @@ def add( """ if isinstance(method, BaseResponse): - self._registry.add(method) - return + return self._registry.add(method) if adding_headers is not None: kwargs.setdefault("headers", adding_headers) @@ -661,28 +660,29 @@ def add( " Using the `content_type` kwarg is recommended." ) - self._registry.add(Response(method=method, url=url, body=body, **kwargs)) + response = Response(method=method, url=url, body=body, **kwargs) + return self._registry.add(response) def delete(self, *args, **kwargs): - self.add(DELETE, *args, **kwargs) + return self.add(DELETE, *args, **kwargs) def get(self, *args, **kwargs): - self.add(GET, *args, **kwargs) + return self.add(GET, *args, **kwargs) def head(self, *args, **kwargs): - self.add(HEAD, *args, **kwargs) + return self.add(HEAD, *args, **kwargs) def options(self, *args, **kwargs): - self.add(OPTIONS, *args, **kwargs) + return self.add(OPTIONS, *args, **kwargs) def patch(self, *args, **kwargs): - self.add(PATCH, *args, **kwargs) + return self.add(PATCH, *args, **kwargs) def post(self, *args, **kwargs): - self.add(POST, *args, **kwargs) + return self.add(POST, *args, **kwargs) def put(self, *args, **kwargs): - self.add(PUT, *args, **kwargs) + return self.add(PUT, *args, **kwargs) def add_passthru(self, prefix): """ @@ -718,7 +718,7 @@ def remove(self, method_or_response=None, url=None): else: response = BaseResponse(method=method_or_response, url=url) - self._registry.remove(response) + return self._registry.remove(response) def replace(self, method_or_response=None, url=None, body="", *args, **kwargs): """ @@ -735,7 +735,7 @@ def replace(self, method_or_response=None, url=None, body="", *args, **kwargs): else: response = Response(method=method_or_response, url=url, body=body, **kwargs) - self._registry.replace(response) + return self._registry.replace(response) def upsert(self, method_or_response=None, url=None, body="", *args, **kwargs): """ @@ -748,9 +748,9 @@ def upsert(self, method_or_response=None, url=None, body="", *args, **kwargs): >>> responses.upsert(responses.GET, 'http://example.org', json={'data': 2}) """ try: - self.replace(method_or_response, url, body, *args, **kwargs) + return self.replace(method_or_response, url, body, *args, **kwargs) except ValueError: - self.add(method_or_response, url, body, *args, **kwargs) + return self.add(method_or_response, url, body, *args, **kwargs) def add_callback( self, diff --git a/responses/__init__.pyi b/responses/__init__.pyi index c95ee6f5..89a2bfa3 100644 --- a/responses/__init__.pyi +++ b/responses/__init__.pyi @@ -210,7 +210,7 @@ class _Add(Protocol): adding_headers: HeaderSet = ..., match_querystring: bool = ..., match: MatcherIterable = ..., - ) -> None: ... + ) -> BaseResponse: ... class _Shortcut(Protocol): def __call__( @@ -226,7 +226,7 @@ class _Shortcut(Protocol): adding_headers: HeaderSet = ..., match_querystring: bool = ..., match: MatcherIterable = ..., - ) -> None: ... + ) -> BaseResponse: ... class _AddCallback(Protocol): def __call__( @@ -249,7 +249,7 @@ class _Remove(Protocol): self, method_or_response: Optional[Union[str, BaseResponse]] = ..., url: Optional[Union[Pattern[str], str]] = ..., - ) -> None: ... + ) -> List[BaseResponse]: ... class _Replace(Protocol): def __call__( @@ -265,7 +265,7 @@ class _Replace(Protocol): adding_headers: HeaderSet = ..., match_querystring: bool = ..., match: MatcherIterable = ..., - ) -> None: ... + ) -> BaseResponse: ... class _Upsert(Protocol): def __call__( @@ -281,7 +281,7 @@ class _Upsert(Protocol): adding_headers: HeaderSet = ..., match_querystring: bool = ..., match: MatcherIterable = ..., - ) -> None: ... + ) -> BaseResponse: ... class _Registered(Protocol): def __call__(self) -> List[Response]: ... diff --git a/responses/registries.py b/responses/registries.py index 9bad2744..049df23f 100644 --- a/responses/registries.py +++ b/responses/registries.py @@ -46,19 +46,24 @@ def find( match_failed_reasons.append(reason) return found_match, match_failed_reasons - def add(self, response: "BaseResponse") -> None: - if response in self.registered: - # if user adds multiple responses that reference the same instance + def add(self, response: "BaseResponse") -> "BaseResponse": + if any(response is resp for resp in self.registered): + # if user adds multiple responses that reference the same instance. + # do a comparison by memory allocation address. # see https://github.com/getsentry/responses/issues/479 response = copy.deepcopy(response) self.registered.append(response) + return response - def remove(self, response: "BaseResponse") -> None: + def remove(self, response: "BaseResponse") -> List["BaseResponse"]: + removed_responses = [] while response in self.registered: self.registered.remove(response) + removed_responses.append(response) + return removed_responses - def replace(self, response: "BaseResponse") -> None: + def replace(self, response: "BaseResponse") -> "BaseResponse": try: index = self.registered.index(response) except ValueError: @@ -66,6 +71,7 @@ def replace(self, response: "BaseResponse") -> None: "Response is not registered for URL {}".format(response.url) ) self.registered[index] = response + return response class OrderedRegistry(FirstMatchRegistry): diff --git a/responses/tests/test_responses.py b/responses/tests/test_responses.py index d0031c94..47479bbc 100644 --- a/responses/tests/test_responses.py +++ b/responses/tests/test_responses.py @@ -1835,6 +1835,50 @@ def run(): assert_reset() +def test_call_count_with_matcher(): + @responses.activate + def run(): + rsp = responses.add( + responses.GET, + "http://www.example.com", + match=(matchers.query_param_matcher({}),), + ) + rsp2 = responses.add( + responses.GET, + "http://www.example.com", + match=(matchers.query_param_matcher({"hello": "world"}),), + status=777, + ) + requests.get("http://www.example.com") + resp1 = requests.get("http://www.example.com") + requests.get("http://www.example.com?hello=world") + resp2 = requests.get("http://www.example.com?hello=world") + + assert resp1.status_code == 200 + assert resp2.status_code == 777 + + assert rsp.call_count == 2 + assert rsp2.call_count == 2 + + run() + assert_reset() + + +def test_call_count_without_matcher(): + @responses.activate + def run(): + rsp = responses.add(responses.GET, "http://www.example.com") + requests.get("http://www.example.com") + requests.get("http://www.example.com") + requests.get("http://www.example.com?hello=world") + requests.get("http://www.example.com?hello=world") + + assert rsp.call_count == 4 + + run() + assert_reset() + + def test_fail_request_error(): """ Validate that exception is raised if request URL/Method/kwargs don't match