Skip to content

Commit

Permalink
Merge pull request psf#136 from oldani/bugfix/html_sessions
Browse files Browse the repository at this point in the history
Bugfix/html sessions
  • Loading branch information
kennethreitz committed Mar 20, 2018
2 parents 0f05293 + 9b21faf commit 2ef3d41
Showing 1 changed file with 55 additions and 55 deletions.
110 changes: 55 additions & 55 deletions requests_html.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,8 @@ class BaseParser:
"""

def __init__(self, *, element, session: 'HTTPSession' = None, default_encoding: _DefaultEncoding = None, html: _HTML = None, url: _URL) -> None:
def __init__(self, *, element, default_encoding: _DefaultEncoding = None, html: _HTML = None, url: _URL) -> None:
self.element = element
self.session = session or HTMLSession()
self.url = url
self.skip_anchors = True
self.default_encoding = default_encoding
Expand Down Expand Up @@ -166,47 +165,6 @@ def full_text(self) -> _Text:
"""
return self.lxml.text_content()

def next(self, fetch: bool = False, next_symbol: _NextSymbol = DEFAULT_NEXT_SYMBOL) -> _Next:
"""Attempts to find the next page, if there is one. If ``fetch``
is ``True`` (default), returns :class:`HTML <HTML>` object of
next page. If ``fetch`` is ``False``, simply returns the next URL.
"""

def get_next():
candidates = self.find('a', containing=next_symbol)

for candidate in candidates:
if candidate.attrs.get('href'):
# Support 'next' rel (e.g. reddit).
if 'next' in candidate.attrs.get('rel', []):
return candidate.attrs['href']

# Support 'next' in classnames.
for _class in candidate.attrs.get('class', []):
if 'next' in _class:
return candidate.attrs['href']

if 'page' in candidate.attrs['href']:
return candidate.attrs['href']

try:
# Resort to the last candidate.
return candidates[-1].attrs['href']
except IndexError:
return None

next = get_next()
if next:
url = self._make_absolute(next)
else:
return None

if fetch:
return self.session.get(url)
else:
return url

def find(self, selector: str = "*", *, containing: _Containing = None, clean: bool = False, first: bool = False, _encoding: str = None) -> _Find:
"""Given a CSS Selector, returns a list of
:class:`Element <Element>` objects or a single one.
Expand Down Expand Up @@ -438,7 +396,7 @@ class HTML(BaseParser):
:param default_encoding: Which encoding to default to.
"""

def __init__(self, *, url: str = DEFAULT_URL, html: _HTML, default_encoding: str = DEFAULT_ENCODING) -> None:
def __init__(self, *, session: Union['HTTPSession', 'AsyncHTMLSession'] = None, url: str = DEFAULT_URL, html: _HTML, default_encoding: str = DEFAULT_ENCODING) -> None:

# Convert incoming unicode HTML into bytes.
if isinstance(html, str):
Expand All @@ -451,25 +409,67 @@ def __init__(self, *, url: str = DEFAULT_URL, html: _HTML, default_encoding: str
url=url,
default_encoding=default_encoding
)
self.session = session or HTMLSession()
self.page = None
self.next_symbol = DEFAULT_NEXT_SYMBOL

def __repr__(self) -> str:
return f"<HTML url={self.url!r}>"

def _next(self, fetch: bool = False, next_symbol: _NextSymbol = DEFAULT_NEXT_SYMBOL) -> _Next:
"""Attempts to find the next page, if there is one. If ``fetch``
is ``True`` (default), returns :class:`HTML <HTML>` object of
next page. If ``fetch`` is ``False``, simply returns the next URL.
"""

def get_next():
candidates = self.find('a', containing=next_symbol)

for candidate in candidates:
if candidate.attrs.get('href'):
# Support 'next' rel (e.g. reddit).
if 'next' in candidate.attrs.get('rel', []):
return candidate.attrs['href']

# Support 'next' in classnames.
for _class in candidate.attrs.get('class', []):
if 'next' in _class:
return candidate.attrs['href']

if 'page' in candidate.attrs['href']:
return candidate.attrs['href']

try:
# Resort to the last candidate.
return candidates[-1].attrs['href']
except IndexError:
return None

__next = get_next()
if __next:
url = self._make_absolute(__next)
else:
return None

if fetch:
return self.session.get(url)
else:
return url

def __iter__(self):

next = self

while True:
yield next
try:
next = next.next(fetch=True, next_symbol=self.next_symbol).html
next = next._next(fetch=True, next_symbol=self.next_symbol).html
except AttributeError:
break

def __next__(self):
return self.next(fetch=True, next_symbol=self.next_symbol).html
return self._next(fetch=True, next_symbol=self.next_symbol).html

def add_next_symbol(self, next_symbol):
self.next_symbol.append(next_symbol)
Expand Down Expand Up @@ -577,20 +577,21 @@ class HTMLResponse(requests.Response):
Effectively the same, but with an intelligent ``.html`` property added.
"""

def __init__(self) -> None:
def __init__(self, session: Union['HTMLSession', 'AsyncHTMLSession']) -> None:
super(HTMLResponse, self).__init__()
self._html = None # type: HTML
self.session = session

@property
def html(self) -> HTML:
if not self._html:
self._html = HTML(url=self.url, html=self.content, default_encoding=self.encoding)
self._html = HTML(session=self.session, url=self.url, html=self.content, default_encoding=self.encoding)

return self._html

@classmethod
def _from_response(cls, response):
html_r = cls()
def _from_response(cls, response, session: Union['HTMLSession', 'AsyncHTMLSession']):
html_r = cls(session=session)
html_r.__dict__.update(response.__dict__)
return html_r

Expand Down Expand Up @@ -647,7 +648,7 @@ def request(self, *args, **kwargs) -> HTMLResponse:
# Convert Request object into HTTPRequest object.
r = super(HTMLSession, self).request(*args, **kwargs)

return HTMLResponse._from_response(r)
return HTMLResponse._from_response(r, self)


class AsyncHTMLSession(requests.Session):
Expand All @@ -667,16 +668,15 @@ def __init__(self, loop=None, workers=None,
if mock_browser:
self.headers['User-Agent'] = user_agent()

self.hooks["response"].append(self.response_hook)
self.hooks['response'].append(self.response_hook)

self.loop = loop or asyncio.get_event_loop()
self.thread_pool = ThreadPoolExecutor(max_workers=workers)

@staticmethod
def response_hook(response, **kwargs) -> HTMLResponse:
def response_hook(self, response, **kwargs) -> HTMLResponse:
""" Change response enconding and replace it by a HTMLResponse. """
response.encoding = DEFAULT_ENCODING
return HTMLResponse._from_response(response)
return HTMLResponse._from_response(response, self)

def request(self, *args, **kwargs):
""" Partial original request func and run it in a thread. """
Expand Down

0 comments on commit 2ef3d41

Please sign in to comment.