Skip to content

Commit

Permalink
Clean up grab.base module
Browse files Browse the repository at this point in the history
  • Loading branch information
Some User committed Dec 26, 2022
1 parent 56c9292 commit 27508c6
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 47 deletions.
64 changes: 22 additions & 42 deletions grab/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
from collections.abc import Callable, Generator, Mapping, MutableMapping
from contextlib import contextmanager
from copy import deepcopy
from http.cookiejar import CookieJar
from typing import Any, Generic, Literal, TypeVar, cast
from typing import Any, Generic, TypeVar, cast

__all__ = ["BaseRequest", "BaseExtension", "BaseClient", "BaseTransport"]
RequestT = TypeVar("RequestT", bound="BaseRequest")
Expand All @@ -16,23 +15,6 @@
T = TypeVar("T")


def resolve_transport_entity(
entity: None
| BaseTransport[RequestT, ResponseT]
| type[BaseTransport[RequestT, ResponseT]],
default: type[BaseTransport[RequestT, ResponseT]],
) -> BaseTransport[RequestT, ResponseT]:
if entity and (
not isinstance(entity, BaseTransport) and not issubclass(entity, BaseTransport)
):
raise TypeError("Invalid BaseTransport entity: {}".format(entity))
if entity is None:
return default()
if isinstance(entity, BaseTransport):
return entity
return entity()


class BaseRequest(metaclass=ABCMeta):
init_keys: set[str] = set()

Expand Down Expand Up @@ -61,15 +43,7 @@ class BaseResponse:


class BaseExtension(Generic[RequestT, ResponseT], metaclass=ABCMeta):
ext_handlers: Mapping[
Literal["request:pre"]
| Literal["request_cookies"]
| Literal["response:post"]
| Literal["init-retry"]
| Literal["retry"],
Callable[..., Any],
] = {}

ext_handlers: Mapping[str, Callable[..., Any]] = {}
__slots__ = ()

def __set_name__(self, owner: BaseClient[RequestT, ResponseT], attr: str) -> None:
Expand All @@ -79,19 +53,6 @@ def __set_name__(self, owner: BaseClient[RequestT, ResponseT], attr: str) -> Non
for point_name, func in self.ext_handlers.items():
owner.ext_handlers[point_name].append(func)

def process_prepare_request_post(self, req: RequestT) -> None:
pass

def process_request_cookies(
self, req: RequestT, jar: CookieJar # pylint: disable=unused-argument
) -> None:
pass

def process_response_post(
self, req: RequestT, doc: ResponseT # pylint: disable=unused-argument
) -> None:
pass

@abstractmethod
def reset(self) -> None:
...
Expand Down Expand Up @@ -131,7 +92,7 @@ def __init__(
| BaseTransport[RequestT, ResponseT]
| type[BaseTransport[RequestT, ResponseT]] = None,
):
self.transport = resolve_transport_entity(
self.transport = self.default_transport_class.resolve_entity(
transport, self.default_transport_class
)
for item in self.extensions.values():
Expand Down Expand Up @@ -188,3 +149,22 @@ def wrap_transport_error(self) -> Generator[None, None, None]: # pragma: no cov
@abstractmethod
def request(self, req: RequestT) -> None: # pragma: no cover
raise NotImplementedError

@classmethod
def resolve_entity(
cls,
entity: None
| BaseTransport[RequestT, ResponseT]
| type[BaseTransport[RequestT, ResponseT]],
default: type[BaseTransport[RequestT, ResponseT]],
) -> BaseTransport[RequestT, ResponseT]:
if entity and (
not isinstance(entity, BaseTransport)
and not issubclass(entity, BaseTransport)
):
raise TypeError("Invalid BaseTransport entity: {}".format(entity))
if entity is None:
return default()
if isinstance(entity, BaseTransport):
return entity
return entity()
13 changes: 8 additions & 5 deletions tests/test_types.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from unittest import TestCase

from grab import HttpClient
from grab.base import resolve_transport_entity
from grab.transport import Urllib3Transport
from grab.types import resolve_entity

Expand Down Expand Up @@ -49,20 +48,23 @@ class SuperTransport(Urllib3Transport):
pass

self.assertTrue(
isinstance(resolve_transport_entity(None, SuperTransport), SuperTransport)
isinstance(
Urllib3Transport.resolve_entity(None, SuperTransport), SuperTransport
)
)

def test_resolve_transport_entity_none_nodefault(self) -> None:
with self.assertRaises(TypeError):
resolve_transport_entity(None, None)
Urllib3Transport.resolve_entity(None, None)

def test_resolve_transport_entity_instance(self) -> None:
class SuperTransport(Urllib3Transport):
pass

self.assertTrue(
isinstance(
resolve_transport_entity(SuperTransport(), HttpClient), SuperTransport
Urllib3Transport.resolve_entity(SuperTransport(), HttpClient),
SuperTransport,
)
)

Expand All @@ -72,6 +74,7 @@ class SuperTransport(Urllib3Transport):

self.assertTrue(
isinstance(
resolve_transport_entity(SuperTransport, HttpClient), SuperTransport
Urllib3Transport.resolve_entity(SuperTransport, HttpClient),
SuperTransport,
)
)

0 comments on commit 27508c6

Please sign in to comment.