diff --git a/src/gradient/_utils/__init__.py b/src/gradient/_utils/__init__.py index dc64e29a..36bd62ef 100644 --- a/src/gradient/_utils/__init__.py +++ b/src/gradient/_utils/__init__.py @@ -29,6 +29,7 @@ get_required_header as get_required_header, maybe_coerce_boolean as maybe_coerce_boolean, maybe_coerce_integer as maybe_coerce_integer, + RateLimiter as RateLimiter, ) from ._compat import ( get_args as get_args, diff --git a/src/gradient/_utils/_utils.py b/src/gradient/_utils/_utils.py index 50d59269..08e3aee7 100644 --- a/src/gradient/_utils/_utils.py +++ b/src/gradient/_utils/_utils.py @@ -419,3 +419,48 @@ def json_safe(data: object) -> object: return data.isoformat() return data + + +# Rate Limiting Classes +class RateLimiter: + """Simple token bucket rate limiter.""" + + def __init__(self, requests_per_minute: int = 60) -> None: + """Initialize rate limiter. + + Args: + requests_per_minute: Maximum requests allowed per minute + """ + self.requests_per_minute: int = requests_per_minute + self.tokens: float = float(requests_per_minute) + self.last_refill: float = self._now() + self.refill_rate: float = requests_per_minute / 60.0 # tokens per second + + def _now(self) -> float: + """Get current time in seconds.""" + import time + return time.time() + + def _refill(self) -> None: + """Refill tokens based on elapsed time.""" + now = self._now() + elapsed = now - self.last_refill + self.tokens = min(self.requests_per_minute, self.tokens + elapsed * self.refill_rate) + self.last_refill = now + + def acquire(self, tokens: int = 1) -> bool: + """Try to acquire tokens. Returns True if successful.""" + self._refill() + if self.tokens >= tokens: + self.tokens -= tokens + return True + return False + + def wait_time(self, tokens: int = 1) -> float: + """Get seconds to wait for tokens to be available.""" + self._refill() + if self.tokens >= tokens: + return 0.0 + + needed = tokens - self.tokens + return needed / self.refill_rate diff --git a/tests/test_rate_limiter.py b/tests/test_rate_limiter.py new file mode 100644 index 00000000..3e75be54 --- /dev/null +++ b/tests/test_rate_limiter.py @@ -0,0 +1,56 @@ +"""Tests for rate limiting functionality.""" + +import time +import pytest +from gradient._utils import RateLimiter + + +class TestRateLimiter: + """Test rate limiting functionality.""" + + def test_rate_limiter_basic(self): + """Test basic rate limiter operations.""" + limiter = RateLimiter(requests_per_minute=10) + + # Should allow initial requests + assert limiter.acquire() is True + assert limiter.acquire() is True + + # Should deny when tokens exhausted + limiter.tokens = 0 # Force exhaustion + assert limiter.acquire() is False + + def test_rate_limiter_wait_time(self): + """Test wait time calculation.""" + limiter = RateLimiter(requests_per_minute=60) # 1 request per second + + # Exhaust tokens + limiter.tokens = 0 + + # Should calculate correct wait time + wait_time = limiter.wait_time() + assert wait_time > 0 + assert wait_time <= 1.0 # Should not exceed 1 second + + def test_rate_limiter_refill(self): + """Test token refill over time.""" + limiter = RateLimiter(requests_per_minute=60) # 1 token per second + + # Exhaust tokens + limiter.tokens = 0 + start_time = limiter._now() + + # Wait for refill + time.sleep(0.1) + + # Should have refilled some tokens + limiter._refill() + assert limiter.tokens > 0 + + def test_rate_limiter_custom_rate(self): + """Test custom rate limits.""" + limiter = RateLimiter(requests_per_minute=120) # 2 requests per second + + # Should have double the tokens of default + assert limiter.requests_per_minute == 120 + assert limiter.refill_rate == 2.0 \ No newline at end of file