From d86353e4233e65f2f358c7a04208b40e1741652e Mon Sep 17 00:00:00 2001 From: Yifei Kong Date: Sun, 17 Mar 2024 21:18:17 +0800 Subject: [PATCH] Add more typing conversions for pyright --- curl_cffi/_asyncio_selector.py | 2 +- curl_cffi/aio.py | 4 +--- curl_cffi/requests/headers.py | 13 ++++++++----- curl_cffi/requests/models.py | 19 ++++++++++++++----- curl_cffi/requests/session.py | 30 +++++++++++++++--------------- examples/stream.py | 2 +- tests/unittest/conftest.py | 2 +- tests/unittest/test_curl.py | 3 ++- tests/unittest/test_requests.py | 7 ++++++- 9 files changed, 49 insertions(+), 33 deletions(-) diff --git a/curl_cffi/_asyncio_selector.py b/curl_cffi/_asyncio_selector.py index 24a647c8..7143fea3 100644 --- a/curl_cffi/_asyncio_selector.py +++ b/curl_cffi/_asyncio_selector.py @@ -33,7 +33,7 @@ class _HasFileno(Protocol): def fileno(self) -> int: - pass + return 0 _FileDescriptorLike = Union[int, _HasFileno] diff --git a/curl_cffi/aio.py b/curl_cffi/aio.py index 6b0e5cb0..27ac009d 100644 --- a/curl_cffi/aio.py +++ b/curl_cffi/aio.py @@ -5,8 +5,6 @@ from typing import Any, Dict, Set from weakref import WeakKeyDictionary, WeakSet -import cffi - from ._wrapper import ffi, lib from .const import CurlMOpt from .curl import DEFAULT_CACERT, Curl @@ -134,7 +132,7 @@ def __init__(self, cacert: str = "", loop=None): self._curlm = lib.curl_multi_init() self._cacert = cacert or DEFAULT_CACERT self._curl2future: Dict[Curl, asyncio.Future] = {} # curl to future map - self._curl2curl: Dict[cffi.CData, Curl] = {} # c curl to Curl + self._curl2curl: Dict[ffi.CData, Curl] = {} # c curl to Curl self._sockfds: Set[int] = set() # sockfds self.loop = _get_selector(loop if loop is not None else asyncio.get_running_loop()) self._checker = self.loop.create_task(self._force_timeout()) diff --git a/curl_cffi/requests/headers.py b/curl_cffi/requests/headers.py index b4b25f58..8c451ac5 100644 --- a/curl_cffi/requests/headers.py +++ b/curl_cffi/requests/headers.py @@ -22,8 +22,8 @@ HeaderTypes = Union[ "Headers", - Mapping[str, str], - Mapping[bytes, bytes], + Mapping[str, Optional[str]], + Mapping[bytes, Optional[bytes]], Sequence[Tuple[str, str]], Sequence[Tuple[bytes, bytes]], Sequence[Union[str, bytes]], @@ -35,7 +35,7 @@ def to_str(value: Union[str, bytes], encoding: str = "utf-8") -> str: def to_bytes_or_str(value: str, match_type_of: AnyStr) -> AnyStr: - return value if isinstance(match_type_of, str) else value.encode() + return value if isinstance(match_type_of, str) else value.encode() # pyright: ignore [reportGeneralTypeIssues] SENSITIVE_HEADERS = {"authorization", "proxy-authorization"} @@ -102,7 +102,10 @@ def __init__( elif isinstance(headers, list): if isinstance(headers[0], (str, bytes)): sep = ":" if isinstance(headers[0], str) else b":" - h = [(k, v.lstrip()) for line in headers for k, v in [line.split(sep, maxsplit=1)]] + h = [] + for line in headers: + k, v = line.split(sep, maxsplit=1) # pyright: ignore + h.append((k, v.strip())) elif isinstance(headers[0], tuple): h = headers self._list = [ @@ -111,7 +114,7 @@ def __init__( normalize_header_key(k, lower=True, encoding=encoding), normalize_header_value(v, encoding), ) - for k, v in h + for k, v in h # pyright: ignore ] self._encoding = encoding diff --git a/curl_cffi/requests/models.py b/curl_cffi/requests/models.py index 0274fa2d..87670ffe 100644 --- a/curl_cffi/requests/models.py +++ b/curl_cffi/requests/models.py @@ -1,7 +1,8 @@ import queue import warnings +from concurrent.futures import Future from json import loads -from typing import Any, Dict, List, Optional +from typing import Any, Awaitable, Dict, List, Optional from .. import Curl from .cookies import Cookies @@ -65,7 +66,8 @@ def __init__(self, curl: Optional[Curl] = None, request: Optional[Request] = Non self.history: List[Dict[str, Any]] = [] self.infos: Dict[str, Any] = {} self.queue: Optional[queue.Queue] = None - self.stream_task = None + self.stream_task: Optional[Future] = None + self.astream_task: Optional[Awaitable] = None self.quit_now = None def _decode(self, content: bytes) -> str: @@ -117,6 +119,9 @@ def iter_content(self, chunk_size=None, decode_unicode=False): warnings.warn("chunk_size is ignored, there is no way to tell curl that.") if decode_unicode: raise NotImplementedError() + + assert self.queue and self.curl, "stream mode is not enabled." + while True: chunk = self.queue.get() @@ -133,11 +138,12 @@ def iter_content(self, chunk_size=None, decode_unicode=False): yield chunk def json(self, **kw): - """return a prased json object of the content.""" + """return a parsed json object of the content.""" return loads(self.content, **kw) def close(self): """Close the streaming connection, only valid in stream mode.""" + if self.quit_now: self.quit_now.set() if self.stream_task: @@ -179,6 +185,8 @@ async def aiter_content(self, chunk_size=None, decode_unicode=False): if decode_unicode: raise NotImplementedError() + assert self.queue and self.curl, "stream mode is not enabled." + while True: chunk = await self.queue.get() @@ -209,8 +217,9 @@ async def acontent(self) -> bytes: async def aclose(self): """Close the streaming connection, only valid in stream mode.""" - if self.stream_task: - await self.stream_task + + if self.astream_task: + await self.astream_task # It prints the status code of the response instead of # the object's memory location. diff --git a/curl_cffi/requests/session.py b/curl_cffi/requests/session.py index 145e53fd..5eca7aa9 100644 --- a/curl_cffi/requests/session.py +++ b/curl_cffi/requests/session.py @@ -361,8 +361,8 @@ def _set_curl_options( username, password = self.auth if auth: username, password = auth - c.setopt(CurlOpt.USERNAME, username.encode()) - c.setopt(CurlOpt.PASSWORD, password.encode()) + c.setopt(CurlOpt.USERNAME, username.encode()) # pyright: ignore [reportPossiblyUnboundVariable=none] + c.setopt(CurlOpt.PASSWORD, password.encode()) # pyright: ignore [reportPossiblyUnboundVariable=none] # timeout if timeout is not_set: @@ -813,12 +813,12 @@ def perform(): except CurlError as e: rsp = self._parse_response(c, buffer, header_buffer) rsp.request = req - q.put_nowait(RequestsError(str(e), e.code, rsp)) + cast(queue.Queue, q).put_nowait(RequestsError(str(e), e.code, rsp)) finally: - if not header_recved.is_set(): - header_recved.set() + if not cast(threading.Event, header_recved).is_set(): + cast(threading.Event, header_recved).set() # None acts as a sentinel - q.put(None) + cast(queue.Queue, q).put(None) def cleanup(fut): header_parsed.wait() @@ -828,12 +828,12 @@ def cleanup(fut): stream_task.add_done_callback(cleanup) # Wait for the first chunk - header_recved.wait() + cast(threading.Event, header_recved).wait() rsp = self._parse_response(c, buffer, header_buffer) header_parsed.set() # Raise the exception if something wrong happens when receiving the header. - first_element = _peek_queue(q) + first_element = _peek_queue(cast(queue.Queue, q)) if isinstance(first_element, RequestsError): c.reset() raise first_element @@ -1080,12 +1080,12 @@ async def perform(): except CurlError as e: rsp = self._parse_response(curl, buffer, header_buffer) rsp.request = req - q.put_nowait(RequestsError(str(e), e.code, rsp)) + cast(asyncio.Queue, q).put_nowait(RequestsError(str(e), e.code, rsp)) finally: - if not header_recved.is_set(): - header_recved.set() + if not cast(asyncio.Event, header_recved).is_set(): + cast(asyncio.Event, header_recved).set() # None acts as a sentinel - await q.put(None) + await cast(asyncio.Queue, q).put(None) def cleanup(fut): self.release_curl(curl) @@ -1093,20 +1093,20 @@ def cleanup(fut): stream_task = asyncio.create_task(perform()) stream_task.add_done_callback(cleanup) - await header_recved.wait() + await cast(asyncio.Event, header_recved).wait() # Unlike threads, coroutines does not use preemptive scheduling. # For asyncio, there is no need for a header_parsed event, the # _parse_response will execute in the foreground, no background tasks running. rsp = self._parse_response(curl, buffer, header_buffer) - first_element = _peek_aio_queue(q) + first_element = _peek_aio_queue(cast(asyncio.Queue, q)) if isinstance(first_element, RequestsError): self.release_curl(curl) raise first_element rsp.request = req - rsp.stream_task = stream_task + rsp.astream_task = stream_task rsp.quit_now = quit_now rsp.queue = q return rsp diff --git a/examples/stream.py b/examples/stream.py index f6bca366..e8ea40e5 100644 --- a/examples/stream.py +++ b/examples/stream.py @@ -5,7 +5,7 @@ try: # Python 3.10+ - from contextlib import aclosing + from contextlib import aclosing # pyright: ignore except ImportError: from contextlib import asynccontextmanager diff --git a/tests/unittest/conftest.py b/tests/unittest/conftest.py index a1844d09..4d50acf8 100644 --- a/tests/unittest/conftest.py +++ b/tests/unittest/conftest.py @@ -591,7 +591,7 @@ def __init__(self, port): def run(self): async def serve(port): # GitHub actions only likes 127, not localhost, wtf... - async with websockets.serve(echo, "127.0.0.1", port): + async with websockets.serve(echo, "127.0.0.1", port): # pyright: ignore await asyncio.Future() # run forever asyncio.run(serve(self.port)) diff --git a/tests/unittest/test_curl.py b/tests/unittest/test_curl.py index 6a71fc0d..81464c9a 100644 --- a/tests/unittest/test_curl.py +++ b/tests/unittest/test_curl.py @@ -1,6 +1,7 @@ import base64 import json from io import BytesIO +from typing import cast import pytest @@ -309,7 +310,7 @@ def test_elapsed(server): url = str(server.url) c.setopt(CurlOpt.URL, url.encode()) c.perform() - assert c.getinfo(CurlInfo.TOTAL_TIME) > 0 + assert cast(int, c.getinfo(CurlInfo.TOTAL_TIME)) > 0 def test_reason(server): diff --git a/tests/unittest/test_requests.py b/tests/unittest/test_requests.py index 95a07ee8..ec9f1138 100644 --- a/tests/unittest/test_requests.py +++ b/tests/unittest/test_requests.py @@ -8,6 +8,7 @@ from curl_cffi import CurlOpt, requests from curl_cffi.const import CurlECode, CurlInfo from curl_cffi.requests.errors import SessionClosed +from curl_cffi.requests.models import Response def test_head(server): @@ -190,6 +191,7 @@ def test_too_many_redirects(server): with pytest.raises(requests.RequestsError) as e: requests.get(str(server.url.copy_with(path="/redirect_loop")), max_redirects=2) assert e.value.code == CurlECode.TOO_MANY_REDIRECTS + assert isinstance(e.value.response, Response) assert e.value.response.status_code == 301 @@ -548,6 +550,7 @@ def test_stream_redirect_loop(server): with s.stream("GET", url, max_redirects=2): pass assert e.value.code == CurlECode.TOO_MANY_REDIRECTS + assert isinstance(e.value.response, Response) assert e.value.response.status_code == 301 @@ -559,6 +562,7 @@ def test_stream_redirect_loop_without_close(server): s.get(url, max_redirects=2, stream=True) assert e.value.code == CurlECode.TOO_MANY_REDIRECTS + assert isinstance(e.value.response, Response) assert e.value.response.status_code == 301 @@ -588,6 +592,7 @@ def test_stream_auto_close_with_header_errors(server): with pytest.raises(requests.RequestsError) as e: s.get(url, max_redirects=2, stream=True) assert e.value.code == CurlECode.TOO_MANY_REDIRECTS + assert isinstance(e.value.response, Response) assert e.value.response.status_code == 301 url = str(server.url.copy_with(path="/")) @@ -646,4 +651,4 @@ def test_curl_infos(server): r = s.get(str(server.url)) - assert r.infos[CurlInfo.PRIMARY_IP] == b"127.0.0.1" + assert r.infos[CurlInfo.PRIMARY_IP] == b"127.0.0.1" # pyright: ignore