Skip to content

Commit

Permalink
Merge pull request psf#101 from oldani/feature/async_support
Browse files Browse the repository at this point in the history
Feature/async support
  • Loading branch information
kennethreitz committed Mar 6, 2018
2 parents 3cbcb12 + c7ba3c1 commit 6ab1aff
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 1 deletion.
1 change: 1 addition & 0 deletions Pipfile
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ pytest = "*"
"e1839a8" = {path = ".", editable = true}
sphinx = "*"
mypy = "*"
pytest-asyncio = "*"


[scripts]
Expand Down
36 changes: 36 additions & 0 deletions requests_html.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import sys
import asyncio
from urllib.parse import urlparse, urlunparse
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures._base import TimeoutError
from functools import partial
from typing import Set, Union, List, MutableMapping, Optional

import pyppeteer
Expand Down Expand Up @@ -599,3 +601,37 @@ def request(self, *args, **kwargs) -> HTMLResponse:
r = super(HTMLSession, self).request(*args, **kwargs)

return HTMLResponse._from_response(r)


class AsyncHTMLSession(requests.Session):
""" An async consumable session. """

def __init__(self, loop=None, workers=None,
mock_browser: bool = True, *args, **kwargs):
""" Set or create an event loop and a thread pool.
:param loop: Asyncio lopp to use.
:param workers: Amount of threads to use for executing async calls.
If not pass it will default to the number of processors on the
machine, multiplied by 5. """
super().__init__(*args, **kwargs)

# Mock a web browser's user agent.
if mock_browser:
self.headers['User-Agent'] = user_agent()

self.hooks["response"].append(self.response_hook)

self.loop = loop or asyncio.get_event_loop()
self.thread_pool = ThreadPoolExecutor(max_workers=workers)

@staticmethod
def response_hook(response, **kwargs) -> HTMLResponse:
""" Change response enconding and replace it by a HTMLResponse. """
response.encoding = DEFAULT_ENCODING
return HTMLResponse._from_response(response)

def request(self, *args, **kwargs):
""" Partial original request func and run it in a thread. """
func = partial(super().request, *args, **kwargs)
return self.loop.run_in_executor(self.thread_pool, func)
33 changes: 32 additions & 1 deletion tests/test_requests_html.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import os
from functools import partial

import pytest
from requests_html import HTMLSession, HTML
from requests_html import HTMLSession, AsyncHTMLSession, HTML
from requests_file import FileAdapter

session = HTMLSession()
Expand All @@ -15,12 +16,31 @@ def get():
return session.get(url)


@pytest.fixture
def async_get(event_loop):
""" AsyncSession cannot be created global since it will create
a different loop from pytest-asyncio. """
async_session = AsyncHTMLSession()
async_session.mount('file://', FileAdapter())
path = os.path.sep.join((os.path.dirname(os.path.abspath(__file__)), 'python.html'))
url = 'file://{}'.format(path)

return partial(async_session.get, url)


@pytest.mark.ok
def test_file_get():
r = get()
assert r.status_code == 200


@pytest.mark.ok
@pytest.mark.asyncio
async def test_async_file_get(async_get):
r = await async_get()
assert r.status_code == 200


@pytest.mark.ok
def test_class_seperation():
r = get()
Expand Down Expand Up @@ -53,6 +73,7 @@ def test_containing():
for e in python:
assert 'python' in e.full_text.lower()


@pytest.mark.ok
def test_attrs():
r = get()
Expand All @@ -71,6 +92,16 @@ def test_links():
assert len(about.absolute_links) == 6


@pytest.mark.ok
@pytest.mark.asyncio
async def test_async_links(async_get):
r = await async_get()
about = r.html.find('#about', first=True)

assert len(about.links) == 6
assert len(about.absolute_links) == 6


@pytest.mark.ok
def test_search():
r = get()
Expand Down

0 comments on commit 6ab1aff

Please sign in to comment.