From bb6143d265a048232e27ae04b540ed166d1d8293 Mon Sep 17 00:00:00 2001 From: Maksim Beliaev Date: Tue, 5 Apr 2022 10:03:08 +0200 Subject: [PATCH] added inline types to Response objects --- responses/__init__.py | 102 +++++++++++++++++++++++++----------------- 1 file changed, 61 insertions(+), 41 deletions(-) diff --git a/responses/__init__.py b/responses/__init__.py index 460d24e5..51724222 100644 --- a/responses/__init__.py +++ b/responses/__init__.py @@ -13,8 +13,11 @@ from typing import Any from typing import Callable from typing import Dict +from typing import Iterable from typing import Iterator +from typing import Mapping from typing import Optional +from typing import Tuple from typing import Union from warnings import warn @@ -54,6 +57,7 @@ # Block of type annotations _Body = Union[str, BaseException, "Response", BufferedReader, bytes] +_MatcherIterable = Iterable[Callable[[Any], Callable[..., Any]]] Call = namedtuple("Call", ["request", "response"]) _real_send = HTTPAdapter.send @@ -275,7 +279,7 @@ def _handle_body( data = BytesIO(body) - def is_closed(): + def is_closed() -> bool: """ Real Response uses HTTPResponse as body object. Thus, when method is_closed is called first to check if there is any more @@ -305,23 +309,29 @@ def is_closed(): class BaseResponse(object): - passthrough = False - content_type = None - headers = None - stream = False + passthrough: bool = False + content_type: Optional[str] = None + headers: Optional[Mapping[str, str]] = None + stream: bool = False - def __init__(self, method, url, match_querystring=None, match=()): - self.method = method + def __init__( + self, + method: str, + url: Union[Pattern[str], str], + match_querystring: Union[bool, object] = None, + match: "_MatcherIterable" = (), + ) -> None: + self.method: str = method # ensure the url has a default path set if the url is a string - self.url = _ensure_url_default_path(url) + self.url: Union[Pattern[str], str] = _ensure_url_default_path(url) if self._should_match_querystring(match_querystring): match = tuple(match) + (_query_string_matcher(urlsplit(self.url).query),) - self.match = match - self.call_count = 0 + self.match: "_MatcherIterable" = match + self.call_count: int = 0 - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if not isinstance(other, BaseResponse): return False @@ -336,10 +346,12 @@ def __eq__(self, other): return self_url == other_url - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not self.__eq__(other) - def _should_match_querystring(self, match_querystring_argument): + def _should_match_querystring( + self, match_querystring_argument: Union[bool, object] + ) -> Union[bool, object]: if isinstance(self.url, Pattern): # the old default from <= 0.9.0 return False @@ -358,7 +370,7 @@ def _should_match_querystring(self, match_querystring_argument): return bool(urlsplit(self.url).query) - def _url_matches(self, url, other): + def _url_matches(self, url: Union[Pattern[str], str], other: str) -> bool: if isinstance(url, str): if _has_unicode(url): url = _clean_unicode(url) @@ -372,7 +384,9 @@ def _url_matches(self, url, other): return False @staticmethod - def _req_attr_matches(match, request): + def _req_attr_matches( + match: "_MatcherIterable", request: "PreparedRequest" + ) -> Tuple[bool, str]: for matcher in match: valid, reason = matcher(request) if not valid: @@ -380,7 +394,7 @@ def _req_attr_matches(match, request): return True, "" - def get_headers(self): + def get_headers(self) -> HTTPHeaderDict: headers = HTTPHeaderDict() # Duplicate headers are legal if self.content_type is not None: headers["Content-Type"] = self.content_type @@ -388,10 +402,10 @@ def get_headers(self): headers.extend(self.headers) return headers - def get_response(self, request): + def get_response(self, request: "PreparedRequest") -> None: raise NotImplementedError - def matches(self, request): + def matches(self, request: "PreparedRequest") -> Tuple[bool, str]: if request.method != self.method: return False, "Method does not match" @@ -408,17 +422,17 @@ def matches(self, request): class Response(BaseResponse): def __init__( self, - method, - url, - body="", - json=None, - status=200, - headers=None, - stream=None, - content_type=_UNSET, - auto_calculate_content_length=False, + method: str, + url: Union[Pattern[str], str], + body: _Body = "", + json: Optional[Any] = None, + status: int = 200, + headers: Optional[Mapping[str, str]] = None, + stream: bool = None, + content_type: Optional[str] = _UNSET, + auto_calculate_content_length: bool = False, **kwargs, - ): + ) -> None: # if we were passed a `json` argument, # override the body and content_type if json is not None: @@ -433,9 +447,9 @@ def __init__( else: content_type = "text/plain" - self.body = body - self.status = status - self.headers = headers + self.body: _Body = body + self.status: int = status + self.headers: Optional[Mapping[str, str]] = headers if stream is not None: warn( @@ -443,12 +457,12 @@ def __init__( DeprecationWarning, ) - self.stream = stream - self.content_type = content_type - self.auto_calculate_content_length = auto_calculate_content_length + self.stream: bool = stream + self.content_type: Optional[str] = content_type + self.auto_calculate_content_length: bool = auto_calculate_content_length super().__init__(method, url, **kwargs) - def get_response(self, request): + def get_response(self, request: "PreparedRequest") -> HTTPResponse: if self.body and isinstance(self.body, Exception): raise self.body @@ -487,8 +501,14 @@ def __repr__(self): class CallbackResponse(BaseResponse): def __init__( - self, method, url, callback, stream=None, content_type="text/plain", **kwargs - ): + self, + method: str, + url: Union[Pattern[str], str], + callback: Callable[[Any], Any], + stream: bool = None, + content_type: Optional[str] = "text/plain", + **kwargs, + ) -> None: self.callback = callback if stream is not None: @@ -496,11 +516,11 @@ def __init__( "stream argument is deprecated. Use stream parameter in request directly", DeprecationWarning, ) - self.stream = stream - self.content_type = content_type + self.stream: bool = stream + self.content_type: Optional[str] = content_type super().__init__(method, url, **kwargs) - def get_response(self, request): + def get_response(self, request: "PreparedRequest") -> HTTPResponse: headers = self.get_headers() result = self.callback(request) @@ -538,7 +558,7 @@ def get_response(self, request): class PassthroughResponse(BaseResponse): - passthrough = True + passthrough: bool = True class OriginalResponseShim(object):