From 92958a219dddac046b27e18c905069b87bb0e30f Mon Sep 17 00:00:00 2001 From: Raphael Krupinski <10319569-mattesilver@users.noreply.gitlab.com> Date: Tue, 6 Feb 2024 11:59:51 +0100 Subject: [PATCH] Handle cookies on redirect. --- httpx/_client.py | 7 +- httpx/_models.py | 45 ++++++++++-- tests/client/test_redirects.py | 129 +++++++++++++++++++++++++++++++++ tests/models/test_cookies.py | 40 ++++++++++ 4 files changed, 212 insertions(+), 9 deletions(-) diff --git a/httpx/_client.py b/httpx/_client.py index e2c6702e0c..fb390929d0 100644 --- a/httpx/_client.py +++ b/httpx/_client.py @@ -344,7 +344,8 @@ def build_request( """ url = self._merge_url(url) headers = self._merge_headers(headers) - cookies = self._merge_cookies(cookies) + user_cookies = Cookies.for_url(url, cookies) + cookies = self._merge_cookies(user_cookies) params = self._merge_queryparams(params) extensions = {} if extensions is None else extensions if "timeout" not in extensions: @@ -364,6 +365,7 @@ def build_request( params=params, headers=headers, cookies=cookies, + user_cookies=user_cookies, extensions=extensions, ) @@ -460,12 +462,13 @@ def _build_redirect_request(self, request: Request, response: Response) -> Reque url = self._redirect_url(request, response) headers = self._redirect_headers(request, url, method) stream = self._redirect_stream(request, method) - cookies = Cookies(self.cookies) + cookies = self._merge_cookies(request.user_cookies) return Request( method=method, url=url, headers=headers, cookies=cookies, + user_cookies=request.user_cookies, stream=stream, extensions=request.extensions, ) diff --git a/httpx/_models.py b/httpx/_models.py index cd76705f1a..b84156233e 100644 --- a/httpx/_models.py +++ b/httpx/_models.py @@ -314,6 +314,7 @@ def __init__( params: QueryParamTypes | None = None, headers: HeaderTypes | None = None, cookies: CookieTypes | None = None, + user_cookies: Cookies | None = None, content: RequestContent | None = None, data: RequestData | None = None, files: RequestFiles | None = None, @@ -332,6 +333,12 @@ def __init__( self.headers = Headers(headers) self.extensions = {} if extensions is None else extensions + # Original cookies passed by the client code, extended with domain. + # Used by follow-up requests, when follow_redirects == True + self.user_cookies = ( + Cookies.for_url(self.url, cookies) if user_cookies is None else user_cookies + ) + if cookies: Cookies(cookies).set_cookie_header(self) @@ -434,7 +441,7 @@ def __getstate__(self) -> dict[str, typing.Any]: return { name: value for name, value in self.__dict__.items() - if name not in ["extensions", "stream"] + if name not in ["extensions", "stream", "user_cookies"] } def __setstate__(self, state: dict[str, typing.Any]) -> None: @@ -1030,6 +1037,21 @@ def __init__(self, cookies: CookieTypes | None = None) -> None: else: self.jar = cookies + @classmethod + def for_url(cls, url: URL, cookies: CookieTypes | None = None) -> "Cookies": + if cookies is None or isinstance(cookies, (Cookies, CookieJar)): + return cls(cookies) + if isinstance(cookies, Mapping): + cookies = cookies.items() # type: ignore + + domain = url.host + secure = url.scheme == "https" + cookies_obj = Cookies() + + for name, value in cookies: # type: ignore + cookies_obj.set(name, value, domain, secure=secure) + return cookies_obj + def extract_cookies(self, response: Response) -> None: """ Loads any cookies based on the response `Set-Cookie` headers. @@ -1046,7 +1068,14 @@ def set_cookie_header(self, request: Request) -> None: urllib_request = self._CookieCompatRequest(request) self.jar.add_cookie_header(urllib_request) - def set(self, name: str, value: str, domain: str = "", path: str = "/") -> None: + def set( + self, + name: str, + value: str, + domain: str = "", + path: str = "/", + secure: bool = False, + ) -> None: """ Set a cookie value by name. May optionally include domain and path. """ @@ -1061,7 +1090,7 @@ def set(self, name: str, value: str, domain: str = "", path: str = "/") -> None: "domain_initial_dot": domain.startswith("."), "path": path, "path_specified": bool(path), - "secure": False, + "secure": secure, "expires": None, "discard": True, "comment": None, @@ -1078,6 +1107,7 @@ def get( # type: ignore default: str | None = None, domain: str | None = None, path: str | None = None, + secure: bool | None = None, ) -> str | None: """ Get a cookie by name. May optionally include domain and path @@ -1088,10 +1118,11 @@ def get( # type: ignore if cookie.name == name: if domain is None or cookie.domain == domain: if path is None or cookie.path == path: - if value is not None: - message = f"Multiple cookies exist with name={name}" - raise CookieConflict(message) - value = cookie.value + if secure is None or cookie.secure == secure: + if value is not None: + message = f"Multiple cookies exist with name={name}" + raise CookieConflict(message) + value = cookie.value if value is None: return default diff --git a/tests/client/test_redirects.py b/tests/client/test_redirects.py index f65827134c..c8d8095700 100644 --- a/tests/client/test_redirects.py +++ b/tests/client/test_redirects.py @@ -445,3 +445,132 @@ async def test_async_invalid_redirect(): await client.get( "http://example.org/invalid_redirect", follow_redirects=True ) + + +def cookies_redirects(request: httpx.Request) -> httpx.Response: + if request.url.scheme not in ("http", "https"): + raise httpx.UnsupportedProtocol(f"Scheme {request.url.scheme!r} not supported.") + + if request.url.path == "/redir_echo": + status_code = httpx.codes.MOVED_PERMANENTLY + headers = {"location": "https://example.com/echo"} + return httpx.Response(status_code, headers=headers) + + if request.url.path == "/redir_double": + status_code = httpx.codes.MOVED_PERMANENTLY + headers = {"location": "https://not-example.com/redir_echo"} + return httpx.Response(status_code, headers=headers) + + elif request.url.path == "/redir_other": + status_code = httpx.codes.MOVED_PERMANENTLY + headers = {"location": "https://not-example.com/echo"} + return httpx.Response(status_code, headers=headers) + + elif request.url.path == "/redir_http": + status_code = httpx.codes.MOVED_PERMANENTLY + headers = {"location": "http://example.com/echo"} + return httpx.Response(status_code, headers=headers) + + elif request.url.path == "/echo": + data = {"cookies": request.headers.get("cookie")} + return httpx.Response(200, json=data) + + return httpx.Response(404, html="Not found!") + + +def test_cookies_dont_cross_domain_on_redirect(): + cookies = httpx.Cookies() + cookies.set("with_domain", "example-value", domain="example.com") + + client = httpx.Client( + transport=httpx.MockTransport(cookies_redirects), + follow_redirects=True, + cookies=cookies, + ) + + response = client.get("http://example.com/redir_other") + assert response.status_code == 200 + assert response.json() == {"cookies": None} + + +def test_dict_cookies_dont_cross_domain_on_redirect(): + cookies = { + "with_domain": "example-value", + } + + client = httpx.Client( + transport=httpx.MockTransport(cookies_redirects), + follow_redirects=True, + ) + + with pytest.warns(DeprecationWarning): + response = client.get("http://example.com/redir_other", cookies=cookies) + assert response.status_code == 200 + assert response.json() == {"cookies": None} + + +def test_dict_cookies_follow_redirect(): + cookies = { + "with_domain": "example-value", + } + + client = httpx.Client( + transport=httpx.MockTransport(cookies_redirects), + follow_redirects=True, + ) + + with pytest.warns(DeprecationWarning): + response = client.get("http://example.com/redir_echo", cookies=cookies) + assert response.status_code == 200 + assert response.json() == {"cookies": "with_domain=example-value"} + + +def test_request_cookies_dont_cross_domain_on_redirect(): + cookies = httpx.Cookies() + cookies.set("with_domain", "example-value", domain="example.com") + + client = httpx.Client( + transport=httpx.MockTransport(cookies_redirects), + follow_redirects=True, + ) + + with pytest.warns(DeprecationWarning): + response = client.get( + "http://example.com/redir_other", + cookies=cookies, + ) + assert response.status_code == 200 + assert response.json() == {"cookies": None} + + +def test_request_cookies_follow_double_redirect_across_hosts(): + cookies = { + "with_domain": "example-value", + } + + with httpx.Client( + transport=httpx.MockTransport(cookies_redirects), follow_redirects=True + ) as client: + with pytest.warns(DeprecationWarning): + response = client.get("http://example.com/redir_double", cookies=cookies) + + assert response.status_code == 200 + assert response.json() == {"cookies": "with_domain=example-value"} + + intermediate_response = response.history[1] + assert "Cookie" not in intermediate_response.request.headers + + +def test_request_cookies_dont_follow_on_http_downgrade(): + cookies = { + "with_domain": "example-value", + } + + with httpx.Client( + transport=httpx.MockTransport(cookies_redirects), follow_redirects=True + ) as client: + with pytest.warns(DeprecationWarning): + response = client.get("https://example.com/redir_http", cookies=cookies) + + assert response.status_code == 200 + assert response.json() == {"cookies": None} diff --git a/tests/models/test_cookies.py b/tests/models/test_cookies.py index f7abe11ad4..b590d5ae57 100644 --- a/tests/models/test_cookies.py +++ b/tests/models/test_cookies.py @@ -52,6 +52,46 @@ def test_cookies_with_domain_and_path(): assert len(cookies) == 0 +def test_cookies_for_url_cookies(): + cookies = httpx.Cookies() + cookies.set("name", "value") + assert httpx.Cookies.for_url(httpx.URL("http://example.com/"), cookies) == cookies + + +def test_cookies_for_url_cookiejar(): + cookies = httpx.Cookies() + cookies.set("name", "value") + assert ( + httpx.Cookies.for_url(httpx.URL("http://example.com/"), cookies.jar) == cookies + ) + + +def test_cookies_for_url_http_dict(): + cookies = httpx.Cookies.for_url(httpx.URL("http://example.com/"), {"name": "value"}) + assert cookies.get("name", domain="example.com") == "value" + + +def test_cookies_for_url_http_list(): + cookies = httpx.Cookies.for_url( + httpx.URL("http://example.com/"), [("name", "value")] + ) + assert cookies.get("name", domain="example.com") == "value" + + +def test_cookies_for_url_https_dict(): + cookies = httpx.Cookies.for_url( + httpx.URL("https://example.com/"), {"name": "value"} + ) + assert cookies.get("name", domain="example.com", secure=True) == "value" + + +def test_cookies_for_url_https_list(): + cookies = httpx.Cookies.for_url( + httpx.URL("https://example.com/"), [("name", "value")] + ) + assert cookies.get("name", domain="example.com", secure=True) == "value" + + def test_multiple_set_cookie(): jar = http.cookiejar.CookieJar() headers = [