Skip to content

Commit

Permalink
Improve types
Browse files Browse the repository at this point in the history
  • Loading branch information
Some User committed Dec 27, 2022
1 parent 99830e6 commit e0b974e
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 40 deletions.
2 changes: 2 additions & 0 deletions .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,5 @@ ignore=F401,C812,C813,D100,D101,D102,D103,D104,D106,D107,D105,P101,PIE798,PIE786
max-line-length=88
inline-quotes = double
max-complexity=10
max-cognitive-complexity=17
max-expression-complexity=8
42 changes: 16 additions & 26 deletions grab/document.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""The Document class is the result of network request made with Grab instance."""

from __future__ import annotations

import email
Expand All @@ -16,7 +17,6 @@
from copy import copy, deepcopy
from http.cookiejar import Cookie
from io import BytesIO, StringIO
from pprint import pprint # pylint: disable=unused-import
from re import Match, Pattern
from typing import Any, TypedDict, cast
from urllib.parse import SplitResult, parse_qs, urljoin, urlsplit
Expand Down Expand Up @@ -64,11 +64,11 @@ class Document(
"document_type",
"code",
"head",
"_bytes_body",
"headers",
"url",
"cookies",
"encoding",
"_bytes_body",
"_unicode_body",
"download_size",
"upload_size",
Expand All @@ -87,7 +87,7 @@ class Document(

def __init__(
self,
body: None | bytes = None,
body: bytes,
*,
document_type: None | str = "html",
head: None | bytes = None,
Expand All @@ -98,7 +98,6 @@ def __init__(
cookies: None | Sequence[Cookie] = None,
) -> None:
# Cache attributes
self._bytes_body: None | bytes = None
self._unicode_body: None | str = None
self._lxml_tree: None | _Element = None
self._strict_lxml_tree: None | _Element = None
Expand All @@ -107,13 +106,12 @@ def __init__(
self._file_fields: MutableMapping[str, Any] = {}
# Main attributes
self.document_type = document_type
if body is not None:
if not isinstance(body, bytes):
raise GrabMisuseError("Document content must be bytes")
self.set_body(body)
if not isinstance(body, bytes):
raise ValueError("Argument 'body' must be bytes")
self._bytes_body = body
self.code = code
self.head = head
self.headers = headers
self.headers: email.message.Message = headers or email.message.Message()
self.url = url
# Encoding must be processed AFTER body and headers are set
self.encoding = self.process_encoding(encoding)
Expand Down Expand Up @@ -171,7 +169,7 @@ def save(self, path: str) -> None:
os.makedirs(path_dir)

with open(path, "wb") as out:
out.write(self._bytes_body if self._bytes_body is not None else b"")
out.write(self.body)

@property
def status(self) -> None | int:
Expand Down Expand Up @@ -234,15 +232,15 @@ def text_search(self, anchor: str | bytes) -> bool:
"""
assert self.body is not None
if isinstance(anchor, str):
return anchor in cast(str, self.unicode_body())
return anchor in self.unicode_body()
return anchor in self.body

def text_assert(self, anchor: str | bytes) -> None:
"""If `anchor` is not found then raise `DataNotFound` exception."""
if not self.text_search(anchor):
raise DataNotFound("Substring not found: {}".format(str(anchor)))

def text_assert_any(self, anchors: list[str | bytes]) -> None:
def text_assert_any(self, anchors: Sequence[str | bytes]) -> None:
"""If no `anchors` were found then raise `DataNotFound` exception."""
if not any(self.text_search(x) for x in anchors):
raise DataNotFound(
Expand Down Expand Up @@ -284,7 +282,7 @@ def rex_search(
match = (
regexp.search(self.body)
if isinstance(regexp.pattern, bytes)
else regexp.search(cast(str, self.unicode_body()))
else regexp.search(self.unicode_body())
)
if match:
return match
Expand Down Expand Up @@ -317,17 +315,13 @@ def pyquery(self) -> Any:

# BodyExtension methods

def get_body_chunk(self) -> None | bytes:
if self._bytes_body:
return self._bytes_body[:4096]
return None
def get_body_chunk(self) -> bytes:
return self.body[:4096]

def unicode_body(
self,
) -> None | str: # , ignore_errors: bool = True) -> None | str:
) -> str:
"""Return response body as unicode string."""
if self.body is None:
return None
if not self._unicode_body:
# FIXME: ignore_errors option
self._unicode_body = unicodec.decode_content(
Expand All @@ -336,17 +330,13 @@ def unicode_body(
return self._unicode_body

@property
def body(self) -> None | bytes:
return cast(bytes, self._bytes_body)
def body(self) -> bytes:
return self._bytes_body

@body.setter
def body(self, body: bytes) -> None:
raise GrabMisuseError("Document body could be set only in constructor")

def set_body(self, body: bytes) -> None:
self._bytes_body = body
self._unicode_body = None

# DomTreeExtension methods

@property
Expand Down
38 changes: 31 additions & 7 deletions tests/test_ext_lxml.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
from lxml.html import fromstring
from __future__ import annotations

import typing
from typing import cast

from lxml.etree import _Element
from lxml.html import HtmlElement, fromstring
from test_server import Response

from grab import DataNotFound, request
Expand Down Expand Up @@ -31,7 +37,6 @@
</ul>
"""


XML = b"""
<root>
<man>
Expand Down Expand Up @@ -64,20 +69,36 @@ def setUp(self) -> None:
def test_lxml_text_content_fail(self) -> None:
# lxml node text_content() method do not put spaces between text
# content of adjacent XML nodes
# pylint: disable=deprecated-typing-alias
self.assertEqual(
self.lxml_tree.xpath('//div[@id="bee"]/div')[0].text_content().strip(),
cast(
typing.List[HtmlElement], self.lxml_tree.xpath('//div[@id="bee"]/div')
)[0]
.text_content()
.strip(),
"пчела",
)
self.assertEqual(
self.lxml_tree.xpath('//div[@id="fly"]')[0].text_content().strip(),
cast(typing.List[HtmlElement], self.lxml_tree.xpath('//div[@id="fly"]'))[0]
.text_content()
.strip(),
"му\nха",
)

def test_lxml_xpath(self) -> None:
names = {x.tag for x in self.lxml_tree.xpath('//div[@id="bee"]//*')}
# pylint: disable=deprecated-typing-alias
names = {
x.tag
for x in cast(
typing.List[_Element], self.lxml_tree.xpath('//div[@id="bee"]//*')
)
}
self.assertEqual({"em", "div", "strong", "style", "script"}, names)
xpath_query = '//div[@id="bee"]//*[name() != "script" and name() != "style"]'
names = {x.tag for x in self.lxml_tree.xpath(xpath_query)}
names = {
x.tag
for x in cast(typing.List[_Element], self.lxml_tree.xpath(xpath_query))
}
self.assertEqual({"em", "div", "strong"}, names)

def test_xpath(self) -> None:
Expand Down Expand Up @@ -136,9 +157,12 @@ def test_xpath_exists(self) -> None:
self.assertFalse(self.doc.select('//li[@id="num-3"]').exists())

def test_cdata_issue(self) -> None:
# pylint: disable=deprecated-typing-alias
self.server.add_response(Response(data=XML), count=2)
doc = request(self.server.get_url(), document_type="xml")
self.assertEqual("30", doc.tree.xpath("//weight")[0].text)
self.assertEqual(
"30", cast(typing.List[_Element], doc.tree.xpath("//weight"))[0].text
)

def test_xml_declaration(self) -> None:
# HTML with XML declaration should be processed without errors.
Expand Down
5 changes: 1 addition & 4 deletions tests/test_grab_api.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
from pprint import pprint # pylint: disable=unused-import

from test_server import Response

from grab import HttpClient
from grab.document import Document
from grab.errors import GrabMisuseError
from tests.util import BaseTestCase


Expand Down Expand Up @@ -46,4 +43,4 @@ def test_document_invalid_input(self) -> None:
data = """
<h1>test</h1>
"""
self.assertRaises(GrabMisuseError, Document, data)
self.assertRaises(ValueError, Document, data)
2 changes: 1 addition & 1 deletion tests/test_grab_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class DummyDocument(Document):

class DummyHttpClient(HttpClient):
def request(self, *_args, **_kwargs):
return DummyDocument()
return DummyDocument(b"")

self.assertTrue(
isinstance(
Expand Down
7 changes: 5 additions & 2 deletions tests/test_grab_proxy.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from __future__ import annotations

# from proxylist.source import BaseFileProxySource
# from proxylist import ProxyList
from test_server import Response, TestServer

from grab import request
from tests.util import BaseTestCase # , temp_file

# from proxylist.source import BaseFileProxySource
# from proxylist import ProxyList


TestServer.__test__ = False # make pytest do not explore it for test cases
ADDRESS = "127.0.0.1"

Expand All @@ -21,6 +23,7 @@ def setUpClass(cls) -> None:
for _ in range(3):
serv = TestServer(address=ADDRESS, port=0)
serv.start()
assert serv.port is not None
cls.extra_servers[serv.port] = {
"server": serv,
"proxy": "%s:%d" % (ADDRESS, serv.port),
Expand Down

0 comments on commit e0b974e

Please sign in to comment.