Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Httpsession requests typing #2699

Draft
wants to merge 9 commits into
base: master
Choose a base branch
from
84 changes: 79 additions & 5 deletions locust/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

import re
import time
from collections.abc import Generator
from contextlib import contextmanager
from typing import TYPE_CHECKING
from urllib.parse import urlparse, urlunparse

import requests
Expand All @@ -15,6 +15,38 @@

from .exception import CatchResponseError, LocustError, ResponseError

if TYPE_CHECKING:
from collections.abc import Callable, Generator, Iterable, Mapping, MutableMapping
from typing import Any, TypedDict

from requests.cookies import RequestsCookieJar
from typing_extensions import Unpack

# Annotations below were generated using output from mypy.
# Mypy underneath uses information from the https://github.com/python/typeshed repo.

class RequestKwargs(TypedDict, total=False):
params: Any | None # simplified signature
data: Any | None # simplified signature
headers: Mapping[str, str | bytes | None] | None
cookies: RequestsCookieJar | MutableMapping[str, str] | None
files: Any | None # simplified signature
auth: Any | None # simplified signature
timeout: float | tuple[float, float] | tuple[float, None] | None
allow_redirects: bool
proxies: MutableMapping[str, str] | None
hooks: Mapping[str, Iterable[Callable[[Response], Any]] | Callable[[Response], Any]] | None
stream: bool | None
verify: bool | str | None
cert: str | tuple[str, str] | None
json: Any | None

class RESTKwargs(RequestKwargs, total=False):
name: str | None
catch_response: bool
context: dict


absolute_http_url_regexp = re.compile(r"^https?://", re.I)


Expand Down Expand Up @@ -94,7 +126,15 @@ def rename_request(self, name: str) -> Generator[None, None, None]:
finally:
self.request_name = None

def request(self, method, url, name=None, catch_response=False, context={}, **kwargs):
def request( # type: ignore[override]
self,
method: str | bytes,
url: str | bytes,
name: str | None = None,
catch_response: bool = False,
context: dict = {},
**kwargs: Unpack[RequestKwargs],
):
"""
Constructs and sends a :py:class:`requests.Request`.
Returns :py:class:`requests.Response` object.
Expand All @@ -108,7 +148,8 @@ def request(self, method, url, name=None, catch_response=False, context={}, **kw
response, even if the response code is ok (2xx). The opposite also works, one can use catch_response to catch a request
and then mark it as successful even if the response code was not (i.e 500 or 404).
:param params: (optional) Dictionary or bytes to be sent in the query string for the :class:`Request`.
:param data: (optional) Dictionary or bytes to send in the body of the :class:`Request`.
:param data: (optional) Dictionary, list of tuples, bytes, or file-like object to send in the body of the :class:`Request`.
:param json: (optional) json to send in the body of the :class:`Request`.
:param headers: (optional) Dictionary of HTTP Headers to send with the :class:`Request`.
:param cookies: (optional) Dict or CookieJar object to send with the :class:`Request`.
:param files: (optional) Dictionary of ``'filename': file-like-objects`` for multipart encoding upload.
Expand All @@ -117,9 +158,17 @@ def request(self, method, url, name=None, catch_response=False, context={}, **kw
:type timeout: float or tuple
:param allow_redirects: (optional) Set to True by default.
:type allow_redirects: bool
:param proxies: (optional) Dictionary mapping protocol to the URL of the proxy.
:param proxies: (optional) Dictionary mapping protocol or protocol and hostname to the URL of the proxy.
:param hooks: (optional) Dictionary mapping hook name to one event or list of events, event must be callable.
:param stream: (optional) whether to immediately download the response content. Defaults to ``False``.
:param verify: (optional) if ``True``, the SSL cert will be verified. A CA_BUNDLE path can also be provided.
:param verify: (optional) Either a boolean, in which case it controls whether we verify
the server's TLS certificate, or a string, in which case it must be a path
to a CA bundle to use. Defaults to ``True``. When set to
``False``, requests will accept any TLS certificate presented by
the server, and will ignore hostname mismatches and/or expired
certificates, which will make your application vulnerable to
man-in-the-middle (MitM) attacks. Setting verify to ``False``
may be useful during local development or testing.
:param cert: (optional) if String, path to ssl client cert file (.pem). If Tuple, ('cert', 'key') pair.
"""

Expand Down Expand Up @@ -187,6 +236,31 @@ def _send_request_safe_mode(self, method, url, **kwargs):
r.request = Request(method, url).prepare()
return r

# These # type: ignore[override] comments below are needed because our overridden version of functions receives
# more arguments than functions in the base class.
def get(self, url: str | bytes, **kwargs: Unpack[RESTKwargs]): # type: ignore[override]
return super().get(url, **kwargs) # type: ignore[misc]

def options(self, url: str | bytes, **kwargs: Unpack[RESTKwargs]): # type: ignore[override]
return super().options(url, **kwargs) # type: ignore[misc]

def head(self, url: str | bytes, **kwargs: Unpack[RESTKwargs]): # type: ignore[override]
return super().head(url, **kwargs) # type: ignore[misc]

# These # type: ignore[misc] comments below are needed because data and json parameters are already defined in the
# RESTKwargs TypedDict. An alternative approach is to define another TypedDict which doesn't contain them.
def post(self, url: str | bytes, data: Any | None = None, json: Any | None = None, **kwargs: Unpack[RESTKwargs]): # type: ignore[override, misc]
return super().post(url, data=data, json=json, **kwargs) # type: ignore[misc]

def put(self, url: str | bytes, data: Any | None = None, **kwargs: Unpack[RESTKwargs]): # type: ignore[override, misc]
return super().put(url, data=data, **kwargs) # type: ignore[misc]

def patch(self, url: str | bytes, data: Any | None = None, **kwargs: Unpack[RESTKwargs]): # type: ignore[override, misc]
return super().patch(url, data=data, **kwargs) # type: ignore[misc]

def delete(self, url: str | bytes, **kwargs: Unpack[RESTKwargs]): # type: ignore[override]
return super().delete(url, **kwargs) # type: ignore[misc]


class ResponseContextManager(LocustResponse):
"""
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ dependencies = [
"geventhttpclient >=2.3.1",
"ConfigArgParse >=1.5.5",
"tomli >=1.1.0; python_version<'3.11'",
"typing_extensions >=4.6.0",
"psutil >=5.9.1",
"Flask-Login >=0.6.3",
"Flask-Cors >=3.0.10",
Expand Down