Skip to content

Commit

Permalink
Limit requests (#35)
Browse files Browse the repository at this point in the history
* Add RequestScheduler class to limit the number of simultaneous requests

* Proxy task with a future

* Remove done Queue, rename task to future

* Fix the tests

* Prevent task callback from setting result/exception on a done future

* Fix requests not being awaited when future was cancelled before request was scheduled

* Fix flake8 errors

* Remove requests from constructor, set schedule_request and on_completion private
  • Loading branch information
NicolasAubry authored and lphuberdeau committed Apr 19, 2018
1 parent cb5a559 commit 104d0cd
Show file tree
Hide file tree
Showing 3 changed files with 192 additions and 17 deletions.
35 changes: 18 additions & 17 deletions hammertime/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from .http import Entry
from .ruleset import Heuristics, HammerTimeException
from .engine import RetryEngine
from .requestscheduler import RequestScheduler
import signal


Expand All @@ -45,6 +46,7 @@ def __init__(self, loop=None, request_engine=None, kb=None, retry_count=0, proxy
self.loop.add_signal_handler(signal.SIGINT, self._interrupt)
self._success_iterator = None
self._interrupted = False
self._request_scheduler = RequestScheduler(loop=loop)

@property
def completed_count(self):
Expand All @@ -66,16 +68,15 @@ def request(self, *args, **kwargs):
return future

self.stats.requested += 1
task = self.loop.create_task(self._request(*args, **kwargs))
self.tasks.append(task)
task.add_done_callback(self._on_completion)
return task
future = self._request_scheduler.request(self._request(*args, **kwargs))
self.tasks.append(future)
future.add_done_callback(self._on_completion)
return future

async def _request(self, *args, **kwargs):
try:
entry = Entry.create(*args, **kwargs)
entry = await self.request_engine.perform(entry, heuristics=self.heuristics)

return entry
except (HammerTimeException, asyncio.CancelledError):
raise
Expand All @@ -93,37 +94,37 @@ def successful_requests(self):
"You must call collect_successful_requests() prior to performing requests."
return self._success_iterator

def _on_completion(self, task):
self._drain(task)
self.tasks.remove(task)
def _on_completion(self, future):
self._drain(future)
self.tasks.remove(future)

if self._success_iterator:
# Checking exception conditions explicitly to avoid using try/except blocks
entry = task.result() if not task.cancelled() and not task.exception() else None
entry = future.result() if not future.cancelled() and not future.exception() else None

self._success_iterator.complete(entry)

def _drain(self, task):
def _drain(self, future):
try:
task.result()
future.result()
except (HammerTimeException, asyncio.CancelledError):
pass
except Exception as e:
logger.exception(e)

async def _cancel_tasks(self):
for t in self.tasks:
if not t.done():
t.cancel()
for future in self.tasks:
if not future.done():
future.cancel()
if len(self.tasks):
await asyncio.wait(self.tasks, loop=self.loop, return_when=asyncio.ALL_COMPLETED)

async def close(self):
if not self.is_closed:
await self._cancel_tasks()
for t in self.tasks:
if t.done():
self._drain(t)
for future in self.tasks:
if future.done():
self._drain(future)

if self.request_engine is not None:
await self.request_engine.close()
Expand Down
90 changes: 90 additions & 0 deletions hammertime/requestscheduler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# hammertime: A high-volume http fetch library
# Copyright (C) 2016- Delve Labs inc.
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.


from asyncio import Future
from collections import deque


class RequestScheduler:
    """Cap the number of request coroutines running at once on an event loop.

    Every submitted request is paired with a proxy future that is handed back
    to the caller. Requests wait in a FIFO queue until a slot is free, then
    run as a task whose outcome is copied onto the proxy future.
    """

    def __init__(self, *, loop, limit=1000):
        """Create a scheduler bound to ``loop``.

        :param loop: event loop used to create the futures and tasks.
        :param limit: maximum number of tasks allowed in ``pending_requests``.
        """
        self.loop = loop
        self.wait_queue = deque()  # (request, proxy future) pairs not yet started
        self.pending_requests = []  # tasks currently running on the loop
        self.max_simultaneous_requests = limit

    def request(self, request, *, schedule=True):
        """Queue ``request`` and return the future that will carry its outcome.

        :param request: coroutine to run once a slot is available.
        :param schedule: when true, immediately start as many queued requests
            as the limit allows; when false, only enqueue.
        :return: a future completed with the result or exception of the
            request, or cancelled if the underlying task is cancelled.
        """
        f = Future(loop=self.loop)
        self.wait_queue.append((request, f))
        if schedule:
            self.schedule_max_possible_requests()
        return f

    def schedule_max_possible_requests(self):
        """Start queued requests, oldest first, until the limit is reached.

        Requests whose proxy future is already done (for instance cancelled
        by the caller while waiting) are disposed of instead of being run.
        """
        # Testing the queue explicitly, rather than catching IndexError around
        # the whole body, keeps an IndexError raised while scheduling from
        # being mistaken for "queue is empty".
        while self.wait_queue and len(self.pending_requests) < self.max_simultaneous_requests:
            request, future = self.wait_queue.popleft()
            if future.done():
                self._cancel_request(request)
            else:
                self._schedule_request(request, future)

    def _schedule_request(self, request, future=None):
        """Run ``request`` as a task and, if given, link it to ``future``."""
        task = self.loop.create_task(request)
        # Registered first so the slot is released (and the next request
        # started) before the proxy future is updated.
        task.add_done_callback(self._on_completion)
        self.pending_requests.append(task)

        if future is not None:
            task.add_done_callback(self._update_future(future))
            future.add_done_callback(self._cancel_sub(task))

    def _on_completion(self, task):
        """Release the slot held by ``task`` and start waiting requests."""
        self.pending_requests.remove(task)
        self.schedule_max_possible_requests()

    def _update_future(self, future):
        """Return a task callback that copies the task outcome onto ``future``."""
        def complete(task):
            if task.cancelled():
                future.cancel()
                return

            # Read the exception even when the proxy is already done, as the
            # original code did, so the task's exception is always consumed.
            exc = task.exception()
            if future.done():
                return

            # Compare with None: an exception instance may define a falsy
            # __bool__/__len__ and must still be propagated as an exception.
            if exc is not None:
                future.set_exception(exc)
            else:
                future.set_result(task.result())

        return complete

    def _cancel_sub(self, task):
        """Return a future callback that cancels ``task`` if the future was cancelled."""
        def complete(future):
            if future.cancelled() and not task.done():
                task.cancel()

        return complete

    def _cancel_request(self, request):
        """Dispose of a request that will never run.

        The request is wrapped in a task that is cancelled straight away, so
        the loop still consumes it instead of leaving it un-awaited.
        """
        task = self.loop.create_task(request)
        # cancel() only requests cancellation; the task is not done yet, so
        # there is no result to read here.
        task.cancel()
84 changes: 84 additions & 0 deletions tests/requestscheduler_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# hammertime: A high-volume http fetch library
# Copyright (C) 2016- Delve Labs inc.
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.


import asyncio
from unittest import TestCase
from unittest.mock import MagicMock, call

from fixtures import async_test
from hammertime.requestscheduler import RequestScheduler


class TestRequestScheduler(TestCase):
    """Unit tests for the request-limiting behaviour of RequestScheduler."""

    def test_schedule_maximum_number_of_requests_on_creation(self):
        """Only ``limit`` of the submitted requests are turned into tasks."""
        max_running = 10
        fake_loop = MagicMock()
        fake_loop.create_task = MagicMock(return_value=MagicMock())
        scheduler = RequestScheduler(loop=fake_loop, limit=max_running)

        for item in range(100):
            scheduler.request(item)

        fake_loop.create_task.assert_has_calls([call(n) for n in range(max_running)])
        running = [fake_loop.create_task.return_value] * max_running
        self.assertEqual(scheduler.pending_requests, running)

    def test_remove_scheduled_futures_from_wait_list(self):
        """Requests that were started no longer sit in the wait queue."""
        total = 100
        max_running = 10
        scheduler = RequestScheduler(loop=MagicMock(), limit=max_running)

        for item in range(total):
            scheduler.request(item)

        still_waiting = total - max_running
        self.assertEqual(len(scheduler.wait_queue), still_waiting)

    @async_test()
    async def test_remove_completed_task_from_pending_requests_list(self, loop):
        """A finished request no longer counts as pending."""
        async def dummy_coro():
            await asyncio.sleep(0)

        scheduler = RequestScheduler(loop=loop)
        await scheduler.request(dummy_coro())

        self.assertEqual(len(scheduler.pending_requests), 0)

    @async_test()
    async def test_schedule_waiting_task_when_task_is_done(self, loop):
        """With a limit of one, the second request starts once the first is done."""
        async def dummy_coro(result):
            await asyncio.wait_for(result, timeout=5)
            return result

        first_gate = asyncio.Future(loop=loop)
        second_gate = asyncio.Future(loop=loop)
        scheduler = RequestScheduler(loop=loop, limit=1)
        first = scheduler.request(dummy_coro(first_gate))
        second = scheduler.request(dummy_coro(second_gate))

        # Release the first request; the second must then take its slot.
        first_gate.set_result(None)
        self.assertEqual(await first, first_gate)
        self.assertEqual(len(scheduler.wait_queue), 0)
        self.assertEqual(len(scheduler.pending_requests), 1)

        second_gate.set_result(None)
        self.assertEqual(await second, second_gate)
        self.assertEqual(len(scheduler.pending_requests), 0)

0 comments on commit 104d0cd

Please sign in to comment.