Skip to content

Commit

Permalink
added inline types to Response objects (#535)
Browse files Browse the repository at this point in the history
  • Loading branch information
beliaev-maksim committed Apr 11, 2022
1 parent bdc5eff commit df920c0
Showing 1 changed file with 62 additions and 42 deletions.
104 changes: 62 additions & 42 deletions responses/__init__.py
Expand Up @@ -13,8 +13,11 @@
from typing import Any
from typing import Callable
from typing import Dict
from typing import Iterable
from typing import Iterator
from typing import Mapping
from typing import Optional
from typing import Tuple
from typing import Union
from warnings import warn

Expand Down Expand Up @@ -55,6 +58,7 @@

# Block of type annotations
_Body = Union[str, BaseException, "Response", BufferedReader, bytes]
_MatcherIterable = Iterable[Callable[[Any], Callable[..., Any]]]

Call = namedtuple("Call", ["request", "response"])
_real_send = HTTPAdapter.send
Expand Down Expand Up @@ -295,7 +299,7 @@ def _handle_body(

data = BytesIO(body)

def is_closed():
def is_closed() -> bool:
"""
Real Response uses HTTPResponse as body object.
Thus, when method is_closed is called first to check if there is any more
Expand Down Expand Up @@ -325,23 +329,29 @@ def is_closed():


class BaseResponse(object):
passthrough = False
content_type = None
headers = None
stream = False
passthrough: bool = False
content_type: Optional[str] = None
headers: Optional[Mapping[str, str]] = None
stream: bool = False

def __init__(self, method, url, match_querystring=None, match=()):
self.method = method
def __init__(
self,
method: str,
url: "Union[Pattern[str], str]",
match_querystring: Union[bool, object] = None,
match: "_MatcherIterable" = (),
) -> None:
self.method: str = method
# ensure the url has a default path set if the url is a string
self.url = _ensure_url_default_path(url)
self.url: "Union[Pattern[str], str]" = _ensure_url_default_path(url)

if self._should_match_querystring(match_querystring):
match = tuple(match) + (_query_string_matcher(urlsplit(self.url).query),)

self.match = match
self.call_count = 0
self.match: "_MatcherIterable" = match
self.call_count: int = 0

def __eq__(self, other):
def __eq__(self, other: Any) -> bool:
if not isinstance(other, BaseResponse):
return False

Expand All @@ -356,10 +366,12 @@ def __eq__(self, other):

return self_url == other_url

def __ne__(self, other):
def __ne__(self, other: Any) -> bool:
return not self.__eq__(other)

def _should_match_querystring(self, match_querystring_argument):
def _should_match_querystring(
self, match_querystring_argument: Union[bool, object]
) -> Union[bool, object]:
if isinstance(self.url, Pattern):
# the old default from <= 0.9.0
return False
Expand All @@ -378,7 +390,7 @@ def _should_match_querystring(self, match_querystring_argument):

return bool(urlsplit(self.url).query)

def _url_matches(self, url, other):
def _url_matches(self, url: "Union[Pattern[str], str]", other: str) -> bool:
if isinstance(url, str):
if _has_unicode(url):
url = _clean_unicode(url)
Expand All @@ -392,26 +404,28 @@ def _url_matches(self, url, other):
return False

@staticmethod
def _req_attr_matches(match, request):
def _req_attr_matches(
match: "_MatcherIterable", request: "PreparedRequest"
) -> Tuple[bool, str]:
for matcher in match:
valid, reason = matcher(request)
if not valid:
return False, reason

return True, ""

def get_headers(self):
def get_headers(self) -> HTTPHeaderDict:
headers = HTTPHeaderDict() # Duplicate headers are legal
if self.content_type is not None:
headers["Content-Type"] = self.content_type
if self.headers:
headers.extend(self.headers)
return headers

def get_response(self, request):
def get_response(self, request: "PreparedRequest") -> None:
raise NotImplementedError

def matches(self, request):
def matches(self, request: "PreparedRequest") -> Tuple[bool, str]:
if request.method != self.method:
return False, "Method does not match"

Expand All @@ -428,17 +442,17 @@ def matches(self, request):
class Response(BaseResponse):
def __init__(
self,
method,
url,
body="",
json=None,
status=200,
headers=None,
stream=None,
content_type=_UNSET,
auto_calculate_content_length=False,
method: str,
url: "Union[Pattern[str], str]",
body: _Body = "",
json: Optional[Any] = None,
status: int = 200,
headers: Optional[Mapping[str, str]] = None,
stream: bool = None,
content_type: Optional[str] = _UNSET,
auto_calculate_content_length: bool = False,
**kwargs,
):
) -> None:
# if we were passed a `json` argument,
# override the body and content_type
if json is not None:
Expand All @@ -453,22 +467,22 @@ def __init__(
else:
content_type = "text/plain"

self.body = body
self.status = status
self.headers = headers
self.body: _Body = body
self.status: int = status
self.headers: Optional[Mapping[str, str]] = headers

if stream is not None:
warn(
"stream argument is deprecated. Use stream parameter in request directly",
DeprecationWarning,
)

self.stream = stream
self.content_type = content_type
self.auto_calculate_content_length = auto_calculate_content_length
self.stream: bool = stream
self.content_type: Optional[str] = content_type
self.auto_calculate_content_length: bool = auto_calculate_content_length
super().__init__(method, url, **kwargs)

def get_response(self, request):
def get_response(self, request: "PreparedRequest") -> HTTPResponse:
if self.body and isinstance(self.body, Exception):
raise self.body

Expand All @@ -493,7 +507,7 @@ def get_response(self, request):
preload_content=False,
)

def __repr__(self):
def __repr__(self) -> str:
return (
"<Response(url='{url}' status={status} "
"content_type='{content_type}' headers='{headers}')>".format(
Expand All @@ -507,20 +521,26 @@ def __repr__(self):

class CallbackResponse(BaseResponse):
def __init__(
self, method, url, callback, stream=None, content_type="text/plain", **kwargs
):
self,
method: str,
url: "Union[Pattern[str], str]",
callback: Callable[[Any], Any],
stream: bool = None,
content_type: Optional[str] = "text/plain",
**kwargs,
) -> None:
self.callback = callback

if stream is not None:
warn(
"stream argument is deprecated. Use stream parameter in request directly",
DeprecationWarning,
)
self.stream = stream
self.content_type = content_type
self.stream: bool = stream
self.content_type: Optional[str] = content_type
super().__init__(method, url, **kwargs)

def get_response(self, request):
def get_response(self, request: "PreparedRequest") -> HTTPResponse:
headers = self.get_headers()

result = self.callback(request)
Expand Down Expand Up @@ -558,7 +578,7 @@ def get_response(self, request):


class PassthroughResponse(BaseResponse):
passthrough = True
passthrough: bool = True


class OriginalResponseShim(object):
Expand Down

0 comments on commit df920c0

Please sign in to comment.