-
Notifications
You must be signed in to change notification settings - Fork 22
/
leakybucket.py
129 lines (105 loc) · 4.25 KB
/
leakybucket.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
# SPDX-License-Identifier: MIT
# Copyright (c) 2019 Martijn Pieters
# Licensed under the MIT license as detailed in LICENSE.txt
import asyncio
from contextlib import AbstractAsyncContextManager
from types import TracebackType
from typing import Dict, Optional, Type
from .compat import wait_for
class AsyncLimiter(AbstractAsyncContextManager):
    """A leaky bucket rate limiter.

    This is an :ref:`asynchronous context manager <async-context-managers>`;
    when used with :keyword:`async with`, entering the context acquires
    capacity::

        limiter = AsyncLimiter(10)
        for foo in bar:
            async with limiter:
                # process foo elements at 10 items per minute

    :param max_rate: Allow up to `max_rate` / `time_period` acquisitions before
       blocking.
    :param time_period: duration, in seconds, of the time period in which to
       limit the rate. Note that up to `max_rate` acquisitions are allowed
       within this time period in a burst.

    """

    __slots__ = (
        "max_rate",
        "time_period",
        "_rate_per_sec",
        "_level",
        "_last_check",
        "_waiters",
    )

    max_rate: float  #: The configured `max_rate` value for this limiter.
    time_period: float  #: The configured `time_period` value for this limiter.

    def __init__(self, max_rate: float, time_period: float = 60) -> None:
        self.max_rate = max_rate
        self.time_period = time_period
        # capacity that drips out of the bucket per second
        self._rate_per_sec = max_rate / time_period
        # current fill level of the bucket; 0 means completely empty
        self._level = 0.0
        # event loop timestamp of the most recent _leak() call
        self._last_check = 0.0
        # queue of waiting futures to signal capacity to, keyed by the
        # task that is blocked in acquire()
        self._waiters: Dict[asyncio.Task, asyncio.Future] = {}

    def _leak(self) -> None:
        """Drip out capacity from the bucket.

        Lowers the level by the amount that leaked since the previous check
        (never below zero) and records the current loop time as the new
        check point. Must be called while an event loop is running.

        """
        now = asyncio.get_running_loop().time()
        if self._level:
            # drip out enough level for the elapsed time since
            # we last checked
            elapsed = now - self._last_check
            decrement = elapsed * self._rate_per_sec
            self._level = max(self._level - decrement, 0.0)
        self._last_check = now

    def has_capacity(self, amount: float = 1) -> bool:
        """Check if there is enough capacity remaining in the limiter

        As a side effect, when the bucket would still be below `max_rate`
        after taking `amount`, the first pending waiter is signalled so it
        can re-check for capacity.

        :param amount: How much capacity you need to be available.
        :return: `True` if `amount` fits in the bucket right now.

        """
        self._leak()
        requested = self._level + amount
        # if there are tasks waiting for capacity, signal to the first
        # that there may be some now (they won't wake up until this task
        # yields with an await)
        if requested < self.max_rate:
            for fut in self._waiters.values():
                if not fut.done():
                    fut.set_result(True)
                    break
        return requested <= self.max_rate

    async def acquire(self, amount: float = 1) -> None:
        """Acquire capacity in the limiter.

        If the limit has been reached, blocks until enough capacity has been
        freed before returning.

        :param amount: How much capacity you need to be available.
        :exception: Raises :exc:`ValueError` if `amount` is greater than
           :attr:`max_rate`.

        """
        if amount > self.max_rate:
            raise ValueError("Can't acquire more than the maximum capacity")

        loop = asyncio.get_running_loop()
        task = asyncio.current_task(loop)
        assert task is not None
        try:
            while not self.has_capacity(amount):
                # wait for the next drip to have left the bucket
                # add a future to the _waiters map to be notified
                # 'early' if capacity has come up
                fut = loop.create_future()
                # re-assigning an existing key keeps this task's position
                # in the waiter queue across retries
                self._waiters[task] = fut
                try:
                    await wait_for(
                        asyncio.shield(fut),
                        1 / self._rate_per_sec * amount,
                        loop=loop,
                    )
                except asyncio.TimeoutError:
                    # enough time passed for `amount` to drip out; re-check
                    pass
                finally:
                    # also runs when this task is cancelled mid-wait, so a
                    # pending future is never left behind
                    fut.cancel()
        finally:
            # always deregister, including on cancellation; otherwise the
            # stale entry would keep the task alive and has_capacity()
            # would spend its wake-up signal on a waiter that is gone
            self._waiters.pop(task, None)

        self._level += amount

        return None

    async def __aenter__(self) -> None:
        """Acquire a single unit of capacity on entering the context."""
        await self.acquire()
        return None

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc: Optional[BaseException],
        tb: Optional[TracebackType],
    ) -> None:
        """Leave the context; nothing is released and exceptions propagate."""
        return None