Skip to content

Commit

Permalink
Move most of HttpClient:request logic to BaseClient:request
Browse files Browse the repository at this point in the history
  • Loading branch information
Some User committed Dec 26, 2022
1 parent a4123e8 commit ca20f4a
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 58 deletions.
41 changes: 38 additions & 3 deletions grab/base.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from __future__ import annotations

import typing
from abc import ABCMeta, abstractmethod
from collections.abc import Callable, Generator, Mapping, MutableMapping
from contextlib import contextmanager
from copy import deepcopy
from http.cookiejar import CookieJar
from typing import Any, Generic, Literal, TypeVar
from typing import Any, Generic, Literal, TypeVar, cast

__all__ = ["BaseRequest", "BaseExtension", "BaseClient", "BaseTransport"]

Expand Down Expand Up @@ -80,8 +81,19 @@ def reset(self) -> None:
...


class Retry:
    """Per-request retry bookkeeping handed to extension handlers.

    A fresh instance is created for every ``BaseClient.request`` call and
    passed to the "init-retry" and "retry" extension handlers, which keep
    their integer counters in ``state``.
    """

    def __init__(self) -> None:
        # Start with no counters; handlers add their own keys.
        counters: MutableMapping[str, int] = {}
        self.state = counters


class BaseClient(Generic[RequestT, ResponseT], metaclass=ABCMeta):
__slots__ = ()
transport: BaseTransport[RequestT, ResponseT]

# Request type used by `request()` to build a request object from keyword
# arguments (via `create_from_mapping`) when the caller passes no request.
# Declared abstract here; HttpClient supplies it as a plain class attribute.
@property
@abstractmethod
def request_class(self) -> type[RequestT]:
...

extensions: MutableMapping[str, MutableMapping[str, Any]] = {}
ext_handlers: Mapping[str, list[Callable[..., Any]]] = {
Expand All @@ -97,9 +109,32 @@ def __init__(self) -> None:
item["instance"].reset()

@abstractmethod
def request(self, req: None | RequestT = None, **request_kwargs: Any) -> ResponseT:
def process_request_result(self, req: RequestT) -> ResponseT:
...

def request(self, req: None | RequestT = None, **request_kwargs: Any) -> ResponseT:
    """Perform a request, repeating it while a retry handler asks for it.

    If ``req`` is None, a request object is built from ``request_kwargs``
    with ``self.request_class.create_from_mapping``.

    Each attempt runs the "request:pre" handlers, resets the transport,
    sends the request and converts the result with
    ``process_request_result``. The "retry" handlers are then consulted in
    order: the first one returning something other than ``(None, None)``
    supplies the ``(retry, req)`` pair for the next attempt and the
    remaining handlers are not called. When no handler asks for a retry,
    the processed result is returned.
    """
    if req is None:
        req = self.request_class.create_from_mapping(request_kwargs)
    retry = Retry()
    # Plain loop instead of all(...): all() stops at the first falsy return
    # value, so a handler returning None would silently prevent every
    # later "init-retry" handler from running.
    for init_func in self.ext_handlers["init-retry"]:
        init_func(retry)
    while True:
        for func in self.ext_handlers["request:pre"]:
            func(req)
        self.transport.reset()
        self.transport.request(req)
        with self.transport.wrap_transport_error():
            doc = self.process_request_result(req)
        for func in self.ext_handlers["retry"]:
            item = func(retry, req, doc)
            if item != (None, None):
                # First handler requesting a retry wins; go around again.
                # pylint: disable=deprecated-typing-alias
                retry, req = cast(typing.Tuple[Retry, RequestT], item)
                break
        else:
            # No handler requested a retry.
            return doc

def clone(self: T) -> T:
    """Return an independent deep copy of this object."""
    duplicate = deepcopy(self)
    return duplicate

Expand All @@ -123,5 +158,5 @@ def wrap_transport_error(self) -> Generator[None, None, None]: # pragma: no cov
raise NotImplementedError

@abstractmethod
def request(self, req: RequestT, cookiejar: CookieJar) -> None: # pragma: no cover
def request(self, req: RequestT) -> None: # pragma: no cover
raise NotImplementedError
45 changes: 7 additions & 38 deletions grab/client.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
from __future__ import annotations

import logging
import typing
from collections.abc import Mapping, MutableMapping
from copy import copy
from http.cookiejar import CookieJar
from pprint import pprint # pylint: disable=unused-import
from typing import Any, cast
from typing import Any

from .base import BaseClient, BaseTransport
from .document import Document
Expand All @@ -24,15 +22,11 @@ def copy_config(config: Mapping[str, Any]) -> MutableMapping[str, Any]:
return {x: copy(y) for x, y in config.items()}


class Retry:
def __init__(self) -> None:
self.state: MutableMapping[str, int] = {}


class HttpClient(BaseClient[HttpRequest, Document]):
document_class: type[Document] = Document
transport_class = Urllib3Transport
extension = RedirectExtension()
request_class = HttpRequest

def __init__(
self,
Expand All @@ -44,39 +38,14 @@ def __init__(
self.transport = resolve_transport_entity(transport, self.transport_class)
super().__init__()

def get_request_cookies(self, req: HttpRequest) -> CookieJar:
jar = CookieJar()
for func in self.ext_handlers["request_cookies"]:
func(req, jar)
return jar

def request(
self, req: None | str | HttpRequest = None, **request_kwargs: Any
) -> Document:
if not isinstance(req, HttpRequest):
if req is not None:
assert isinstance(req, str)
request_kwargs["url"] = req
req = HttpRequest.create_from_mapping(request_kwargs)
retry = Retry()
all(x(retry) for x in self.ext_handlers["init-retry"])
while True:
for func in self.ext_handlers["request:pre"]:
func(req)
self.transport.reset()
self.transport.request(req, self.get_request_cookies(req))
with self.transport.wrap_transport_error():
doc = self.process_request_result(req)
if any(
(
(item := func(retry, req, doc)) != (None, None)
for func in self.ext_handlers["retry"]
)
):
# pylint: disable=deprecated-typing-alias
retry, req = cast(typing.Tuple[Retry, HttpRequest], item)
continue
return doc
if req is not None and not isinstance(req, HttpRequest):
assert isinstance(req, str)
request_kwargs["url"] = req
req = None
return super().request(req, **request_kwargs)

def process_request_result(self, req: HttpRequest) -> Document:
"""Process result of real request performed via transport extension."""
Expand Down
16 changes: 8 additions & 8 deletions grab/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from .document import Document
from .errors import GrabTooManyRedirectsError
from .request import HttpRequest
from .util.cookies import build_jar, create_cookie
from .util.cookies import build_cookie_header, build_jar, create_cookie


class RedirectExtension(BaseExtension[HttpRequest, Document]):
Expand Down Expand Up @@ -53,7 +53,6 @@ def __init__(self, cookiejar: None | CookieJar = None) -> None:
self.cookiejar = cookiejar if cookiejar else CookieJar()
self.ext_handlers = {
"request:pre": self.process_request_pre,
"request_cookies": self.process_request_cookies,
"response:post": self.process_response_post,
}

Expand Down Expand Up @@ -95,12 +94,13 @@ def __setstate__(self, state: Mapping[str, Any]) -> None:

def process_request_pre(self, req: HttpRequest) -> None:
self.update(req.cookies, req.url)

def process_request_cookies(
self, req: HttpRequest, jar: CookieJar # pylint: disable=unused-argument
) -> None:
for cookie in self.cookiejar:
jar.set_cookie(cookie)
if hdr := build_cookie_header(self.cookiejar, req.url, req.headers):
if req.headers.get("Cookie"):
raise ValueError(
"Could not configure request with session cookies"
" because it has already Cookie header"
)
req.cookie_header = hdr

def process_response_post(
self, req: HttpRequest, doc: Document # pylint: disable=unused-argument
Expand Down
9 changes: 3 additions & 6 deletions grab/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,12 @@

from collections.abc import Mapping, MutableMapping
from copy import copy
from http.cookiejar import CookieJar
from typing import Any, TypedDict, cast
from urllib.parse import urlencode

from urllib3.filepost import encode_multipart_formdata

from .base import BaseRequest
from .util.cookies import build_cookie_header
from .util.structures import merge_with_dict
from .util.timeout import Timeout

Expand Down Expand Up @@ -107,6 +105,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals
self.multipart = multipart if multipart is not None else True
self.document_type = document_type
self.meta = meta or {}
self.cookie_header: None | str = None

def get_full_url(self) -> str:
    """Return the request URL exactly as stored in ``self.url``."""
    full_url = self.url
    return full_url
Expand All @@ -120,7 +119,6 @@ def _process_timeout_param(self, value: None | float | Timeout) -> Timeout:

def compile_request_data( # noqa: CCR001
self,
cookiejar: CookieJar,
) -> CompiledRequestData:
req_url = self.url
req_hdr = copy(self.headers)
Expand Down Expand Up @@ -155,9 +153,8 @@ def compile_request_data( # noqa: CCR001
{"Content-Type": content_type, "Content-Length": len(req_body)},
replace=True,
)
cookie_hdr = build_cookie_header(cookiejar, self.url, req_hdr)
if cookie_hdr:
req_hdr["Cookie"] = cookie_hdr
if self.cookie_header:
req_hdr["Cookie"] = self.cookie_header
return {
"method": self.method,
"url": req_url,
Expand Down
5 changes: 2 additions & 3 deletions grab/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from collections.abc import Generator, Mapping
from contextlib import contextmanager
from http.client import HTTPResponse
from http.cookiejar import CookieJar
from pprint import pprint # pylint: disable=unused-import
from typing import Any, cast

Expand Down Expand Up @@ -113,7 +112,7 @@ def log_request(self, req: HttpRequest) -> None:
)
LOG.debug("%s %s%s", req.method or "GET", req.url, proxy_info)

def request(self, req: HttpRequest, cookiejar: CookieJar) -> None:
def request(self, req: HttpRequest) -> None:
pool: PoolManager | SOCKSProxyManager | ProxyManager = (
self.select_pool_for_request(req)
)
Expand All @@ -135,7 +134,7 @@ def request(self, req: HttpRequest, cookiejar: CookieJar) -> None:
# It is the timeout on read of next data chunk from the server
# Total response timeout is handled by Grab
timeout = Timeout(connect=req.timeout.connect, read=req.timeout.read)
req_data = req.compile_request_data(cookiejar)
req_data = req.compile_request_data()
try:
start_time = time.time()
res = pool.urlopen( # type: ignore # FIXME
Expand Down

0 comments on commit ca20f4a

Please sign in to comment.