Skip to content

Commit

Permalink
Refactor HttpClient
Browse files Browse the repository at this point in the history
  • Loading branch information
Some User committed Dec 26, 2022
1 parent ca20f4a commit 56c9292
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 74 deletions.
38 changes: 33 additions & 5 deletions grab/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,30 @@
from typing import Any, Generic, Literal, TypeVar, cast

__all__ = ["BaseRequest", "BaseExtension", "BaseClient", "BaseTransport"]

RequestT = TypeVar("RequestT", bound="BaseRequest")
ResponseT = TypeVar("ResponseT", bound="BaseResponse")
RequestDupT = TypeVar("RequestDupT", bound="BaseRequest")
ResponseDupT = TypeVar("ResponseDupT", bound="BaseResponse")
T = TypeVar("T")


def resolve_transport_entity(
entity: None
| BaseTransport[RequestT, ResponseT]
| type[BaseTransport[RequestT, ResponseT]],
default: type[BaseTransport[RequestT, ResponseT]],
) -> BaseTransport[RequestT, ResponseT]:
if entity and (
not isinstance(entity, BaseTransport) and not issubclass(entity, BaseTransport)
):
raise TypeError("Invalid BaseTransport entity: {}".format(entity))
if entity is None:
return default()
if isinstance(entity, BaseTransport):
return entity
return entity()


class BaseRequest(metaclass=ABCMeta):
init_keys: set[str] = set()

Expand Down Expand Up @@ -87,14 +103,19 @@ def __init__(self) -> None:


class BaseClient(Generic[RequestT, ResponseT], metaclass=ABCMeta):
__slots__ = ()
__slots__ = ["transport"]
transport: BaseTransport[RequestT, ResponseT]

@property
@abstractmethod
def request_class(self) -> type[RequestT]:
...

@property
@abstractmethod
def default_transport_class(self) -> type[BaseTransport[RequestT, ResponseT]]:
...

extensions: MutableMapping[str, MutableMapping[str, Any]] = {}
ext_handlers: Mapping[str, list[Callable[..., Any]]] = {
"request:pre": [],
Expand All @@ -104,7 +125,15 @@ def request_class(self) -> type[RequestT]:
"retry": [],
}

def __init__(self) -> None:
def __init__(
self,
transport: None
| BaseTransport[RequestT, ResponseT]
| type[BaseTransport[RequestT, ResponseT]] = None,
):
self.transport = resolve_transport_entity(
transport, self.default_transport_class
)
for item in self.extensions.values():
item["instance"].reset()

Expand All @@ -118,8 +147,7 @@ def request(self, req: None | RequestT = None, **request_kwargs: Any) -> Respons
retry = Retry()
all(x(retry) for x in self.ext_handlers["init-retry"])
while True:
for func in self.ext_handlers["request:pre"]:
func(req)
all(func(req) for func in self.ext_handlers["request:pre"])
self.transport.reset()
self.transport.request(req)
with self.transport.wrap_transport_error():
Expand Down
28 changes: 4 additions & 24 deletions grab/client.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,23 @@
from __future__ import annotations

import logging
from collections.abc import Mapping, MutableMapping
from copy import copy
from pprint import pprint # pylint: disable=unused-import
from typing import Any

from .base import BaseClient, BaseTransport
from .base import BaseClient
from .document import Document
from .extensions import RedirectExtension
from .request import HttpRequest
from .transport import Urllib3Transport
from .types import resolve_entity, resolve_transport_entity
from .types import resolve_entity

__all__ = ["HttpClient", "request"]
logger = logging.getLogger(__name__)


def copy_config(config: Mapping[str, Any]) -> MutableMapping[str, Any]:
"""Copy grab config with correct handling of mutable config values."""
return {x: copy(y) for x, y in config.items()}


class HttpClient(BaseClient[HttpRequest, Document]):
document_class: type[Document] = Document
transport_class = Urllib3Transport
extension = RedirectExtension()
request_class = HttpRequest

def __init__(
self,
transport: None
| BaseTransport[HttpRequest, Document]
| type[BaseTransport[HttpRequest, Document]] = None,
) -> None:
self.config: MutableMapping[str, Any] = {}
self.transport = resolve_transport_entity(transport, self.transport_class)
super().__init__()
default_transport_class = Urllib3Transport

def request(
self, req: None | str | HttpRequest = None, **request_kwargs: Any
Expand All @@ -50,8 +31,7 @@ def request(
def process_request_result(self, req: HttpRequest) -> Document:
"""Process result of real request performed via transport extension."""
doc = self.transport.prepare_response(req, document_class=self.document_class)
for func in self.ext_handlers["response:post"]:
func(req, doc)
all(func(req, doc) for func in self.ext_handlers["response:post"])
return doc


Expand Down
37 changes: 1 addition & 36 deletions grab/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,46 +10,11 @@
import typing
from typing import TypeVar, cast

from .base import BaseClient, BaseExtension, BaseTransport, RequestT, ResponseT
from .base import BaseExtension, RequestT, ResponseT

T = TypeVar("T")


def resolve_transport_entity(
entity: None
| BaseTransport[RequestT, ResponseT]
| type[BaseTransport[RequestT, ResponseT]],
default: type[BaseTransport[RequestT, ResponseT]],
) -> BaseTransport[RequestT, ResponseT]:
if entity and (
not isinstance(entity, BaseTransport) and not issubclass(entity, BaseTransport)
):
raise TypeError("Invalid BaseTransport entity: {}".format(entity))
if entity is None:
return default()
if isinstance(entity, BaseTransport):
return entity
return entity()


def resolve_grab_entity(
entity: None
| BaseClient[RequestT, ResponseT]
| type[BaseClient[RequestT, ResponseT]],
default: type[BaseClient[RequestT, ResponseT]],
) -> BaseClient[RequestT, ResponseT]:
if entity and (
not isinstance(entity, BaseClient) and not issubclass(entity, BaseClient)
):
raise TypeError("Invalid BaseClient entity: {}".format(entity))
if entity is None:
assert issubclass(default, BaseClient)
return default()
if isinstance(entity, BaseClient):
return entity
return entity()


def resolve_extension_entity(
entity: BaseExtension[RequestT, ResponseT]
| type[BaseExtension[RequestT, ResponseT]],
Expand Down
22 changes: 13 additions & 9 deletions tests/test_types.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,44 @@
from unittest import TestCase

from grab import HttpClient
from grab.base import resolve_transport_entity
from grab.transport import Urllib3Transport
from grab.types import resolve_grab_entity, resolve_transport_entity
from grab.types import resolve_entity


class ResolveHttpClientEntityTestCase(TestCase):
def test_resolve_grab_entity_default(self) -> None:
def test_resolve_entity_default(self) -> None:
class SuperHttpClient(HttpClient):
pass

self.assertTrue(
isinstance(resolve_grab_entity(None, SuperHttpClient), SuperHttpClient)
isinstance(
resolve_entity(HttpClient, None, SuperHttpClient), SuperHttpClient
)
)

def test_resolve_grab_entity_none_nodefault(self) -> None:
def test_resolve_entity_none_nodefault(self) -> None:
with self.assertRaises(TypeError):
resolve_grab_entity(None, None)
resolve_entity(None, None, None)

def test_resolve_grab_entity_instance(self) -> None:
def test_resolve_entity_instance(self) -> None:
class SuperHttpClient(HttpClient):
pass

self.assertTrue(
isinstance(
resolve_grab_entity(SuperHttpClient(), HttpClient), SuperHttpClient
resolve_entity(HttpClient, SuperHttpClient(), HttpClient),
SuperHttpClient,
)
)

def test_resolve_grab_entity_class(self) -> None:
def test_resolve_entity_class(self) -> None:
class SuperHttpClient(HttpClient):
pass

self.assertTrue(
isinstance(
resolve_grab_entity(SuperHttpClient, HttpClient), SuperHttpClient
resolve_entity(HttpClient, SuperHttpClient, HttpClient), SuperHttpClient
)
)

Expand Down

0 comments on commit 56c9292

Please sign in to comment.