From cac09d1b1f0de297f77de1d45827edfd4e6d68ff Mon Sep 17 00:00:00 2001 From: Igor Benav Date: Mon, 10 Feb 2025 18:52:44 -0300 Subject: [PATCH 1/2] rate limiter changed from module to class --- README.md | 8 ++-- src/app/api/dependencies.py | 9 +++-- src/app/api/v1/rate_limits.py | 2 +- src/app/api/v1/tasks.py | 4 +- src/app/core/setup.py | 53 +++++++++++++++++---------- src/app/core/utils/rate_limit.py | 63 ++++++++++++++++++++++---------- 6 files changed, 90 insertions(+), 49 deletions(-) diff --git a/README.md b/README.md index aec89f57..a483187a 100644 --- a/README.md +++ b/README.md @@ -1342,17 +1342,17 @@ async def your_background_function( ### 5.11 Rate Limiting -To limit how many times a user can make a request in a certain interval of time (very useful to create subscription plans or just to protect your API against DDOS), you may just use the `rate_limiter` dependency: +To limit how many times a user can make a request in a certain interval of time (very useful to create subscription plans or just to protect your API against DDOS), you may just use the `rate_limiter_dependency` dependency: ```python from fastapi import Depends -from app.api.dependencies import rate_limiter +from app.api.dependencies import rate_limiter_dependency from app.core.utils import queue from app.schemas.job import Job -@router.post("/task", response_model=Job, status_code=201, dependencies=[Depends(rate_limiter)]) +@router.post("/task", response_model=Job, status_code=201, dependencies=[Depends(rate_limiter_dependency)]) async def create_task(message: str): job = await queue.pool.enqueue_job("sample_background_task", message) return {"id": job.job_id} @@ -1446,7 +1446,7 @@ curl -X POST 'http://127.0.0.1:8000/api/v1/tasks/task?message=test' \ ``` > \[!TIP\] -> Since the `rate_limiter` dependency uses the `get_optional_user` dependency instead of `get_current_user`, it will not require authentication to be used, but will behave accordingly if the user is authenticated (and token is passed in header). If you want to ensure authentication, also use `get_current_user` if you need. +> Since the `rate_limiter_dependency` dependency uses the `get_optional_user` dependency instead of `get_current_user`, it will not require authentication to be used, but will behave accordingly if the user is authenticated (and token is passed in header). If you want to ensure authentication, also use `get_current_user` if you need. To change a user's tier, you may just use the `PATCH api/v1/user/{username}/tier` endpoint. Note that for flexibility (since this is a boilerplate), it's not necessary to previously inform a tier_id to create a user, but you probably should set every user to a certain tier (let's say `free`) once they are created. diff --git a/src/app/api/dependencies.py b/src/app/api/dependencies.py index f9297427..dfb9c93a 100644 --- a/src/app/api/dependencies.py +++ b/src/app/api/dependencies.py @@ -8,7 +8,7 @@ from ..core.exceptions.http_exceptions import ForbiddenException, RateLimitException, UnauthorizedException from ..core.logger import logging from ..core.security import oauth2_scheme, verify_token -from ..core.utils.rate_limit import is_rate_limited +from ..core.utils.rate_limit import rate_limiter from ..crud.crud_rate_limit import crud_rate_limits from ..crud.crud_tier import crud_tiers from ..crud.crud_users import crud_users @@ -72,9 +72,12 @@ async def get_current_superuser(current_user: Annotated[dict, Depends(get_curren return current_user -async def rate_limiter( +async def rate_limiter_dependency( request: Request, db: Annotated[AsyncSession, Depends(async_get_db)], user: User | None = Depends(get_optional_user) ) -> None: + if hasattr(request.app.state, "initialization_complete"): + await request.app.state.initialization_complete.wait() + path = sanitize_path(request.url.path) if user: user_id = user["id"] @@ -96,6 +99,6 @@ async def rate_limiter( user_id = request.client.host limit, period = DEFAULT_LIMIT, DEFAULT_PERIOD - is_limited = await is_rate_limited(db=db, user_id=user_id, path=path, limit=limit, period=period) + is_limited = await rate_limiter.is_rate_limited(db=db, user_id=user_id, path=path, limit=limit, period=period) if is_limited: raise RateLimitException("Rate limit exceeded.") diff --git a/src/app/api/v1/rate_limits.py b/src/app/api/v1/rate_limits.py index 479a903b..4c81c15b 100644 --- a/src/app/api/v1/rate_limits.py +++ b/src/app/api/v1/rate_limits.py @@ -6,7 +6,7 @@ from ...api.dependencies import get_current_superuser from ...core.db.database import async_get_db -from ...core.exceptions.http_exceptions import DuplicateValueException, NotFoundException, RateLimitException +from ...core.exceptions.http_exceptions import DuplicateValueException, NotFoundException from ...crud.crud_rate_limit import crud_rate_limits from ...crud.crud_tier import crud_tiers from ...schemas.rate_limit import RateLimitCreate, RateLimitCreateInternal, RateLimitRead, RateLimitUpdate diff --git a/src/app/api/v1/tasks.py b/src/app/api/v1/tasks.py index 48520006..ca714c8b 100644 --- a/src/app/api/v1/tasks.py +++ b/src/app/api/v1/tasks.py @@ -3,14 +3,14 @@ from arq.jobs import Job as ArqJob from fastapi import APIRouter, Depends -from ...api.dependencies import rate_limiter +from ...api.dependencies import rate_limiter_dependency from ...core.utils import queue from ...schemas.job import Job router = APIRouter(prefix="/tasks", tags=["tasks"]) -@router.post("/task", response_model=Job, status_code=201, dependencies=[Depends(rate_limiter)]) +@router.post("/task", response_model=Job, status_code=201, dependencies=[Depends(rate_limiter_dependency)]) async def create_task(message: str) -> dict[str, str]: """Create a new background task. diff --git a/src/app/core/setup.py b/src/app/core/setup.py index a9c36621..2ddadf4e 100644 --- a/src/app/core/setup.py +++ b/src/app/core/setup.py @@ -12,7 +12,9 @@ from fastapi.openapi.utils import get_openapi from ..api.dependencies import get_current_superuser +from ..core.utils.rate_limit import rate_limiter from ..middleware.client_cache_middleware import ClientCacheMiddleware +from ..models import * from .config import ( AppSettings, ClientSideCacheSettings, @@ -24,9 +26,10 @@ RedisRateLimiterSettings, settings, ) -from .db.database import Base, async_engine as engine +from .db.database import Base +from .db.database import async_engine as engine from .utils import cache, queue, rate_limit -from ..models import * + # -------------- database -------------- async def create_tables() -> None: @@ -55,8 +58,7 @@ async def close_redis_queue_pool() -> None: # -------------- rate limit -------------- async def create_redis_rate_limit_pool() -> None: - rate_limit.pool = redis.ConnectionPool.from_url(settings.REDIS_RATE_LIMIT_URL) - rate_limit.client = redis.Redis.from_pool(rate_limit.pool) # type: ignore + rate_limiter.initialize(settings.REDIS_RATE_LIMIT_URL) # type: ignore async def close_redis_rate_limit_pool() -> None: @@ -85,30 +87,43 @@ def lifespan_factory( @asynccontextmanager async def lifespan(app: FastAPI) -> AsyncGenerator: + from asyncio import Event + + initialization_complete = Event() + app.state.initialization_complete = initialization_complete + + print("1. Starting lifespan") await set_threadpool_tokens() + print("2. Set threadpool tokens") - if isinstance(settings, DatabaseSettings) and create_tables_on_start: - await create_tables() + try: + if isinstance(settings, RedisCacheSettings): + print("3. Starting Redis cache initialization") + await create_redis_cache_pool() - if isinstance(settings, RedisCacheSettings): - await create_redis_cache_pool() + if isinstance(settings, RedisQueueSettings): + print("4. Starting Redis queue initialization") + await create_redis_queue_pool() - if isinstance(settings, RedisQueueSettings): - await create_redis_queue_pool() + if isinstance(settings, RedisRateLimiterSettings): + print("5. Starting Redis rate limit initialization") + await create_redis_rate_limit_pool() - if isinstance(settings, RedisRateLimiterSettings): - await create_redis_rate_limit_pool() + print("6. All initialization complete") + initialization_complete.set() - yield + yield - if isinstance(settings, RedisCacheSettings): - await close_redis_cache_pool() + finally: + print("7. Starting shutdown") + if isinstance(settings, RedisCacheSettings): + await close_redis_cache_pool() - if isinstance(settings, RedisQueueSettings): - await close_redis_queue_pool() + if isinstance(settings, RedisQueueSettings): + await close_redis_queue_pool() - if isinstance(settings, RedisRateLimiterSettings): - await close_redis_rate_limit_pool() + if isinstance(settings, RedisRateLimiterSettings): + await close_redis_rate_limit_pool() return lifespan diff --git a/src/app/core/utils/rate_limit.py b/src/app/core/utils/rate_limit.py index 1c1ba7c6..4919e4d2 100644 --- a/src/app/core/utils/rate_limit.py +++ b/src/app/core/utils/rate_limit.py @@ -1,4 +1,5 @@ from datetime import UTC, datetime +from typing import Optional from redis.asyncio import ConnectionPool, Redis from sqlalchemy.ext.asyncio import AsyncSession @@ -8,31 +9,53 @@ logger = logging.getLogger(__name__) -pool: ConnectionPool | None = None -client: Redis | None = None +class RateLimiter: + _instance: Optional["RateLimiter"] = None + pool: Optional[ConnectionPool] = None + client: Optional[Redis] = None -async def is_rate_limited(db: AsyncSession, user_id: int, path: str, limit: int, period: int) -> bool: - if client is None: - logger.error("Redis client is not initialized.") - raise Exception("Redis client is not initialized.") + def __new__(cls): + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance - current_timestamp = int(datetime.now(UTC).timestamp()) - window_start = current_timestamp - (current_timestamp % period) + @classmethod + def initialize(cls, redis_url: str) -> None: + instance = cls() + if instance.pool is None: + instance.pool = ConnectionPool.from_url(redis_url) + instance.client = Redis(connection_pool=instance.pool) - sanitized_path = sanitize_path(path) - key = f"ratelimit:{user_id}:{sanitized_path}:{window_start}" + @classmethod + def get_client(cls) -> Redis: + instance = cls() + if instance.client is None: + logger.error("Redis client is not initialized.") + raise Exception("Redis client is not initialized.") + return instance.client - try: - current_count = await client.incr(key) - if current_count == 1: - await client.expire(key, period) + async def is_rate_limited(self, db: AsyncSession, user_id: int, path: str, limit: int, period: int) -> bool: + client = self.get_client() + current_timestamp = int(datetime.now(UTC).timestamp()) + window_start = current_timestamp - (current_timestamp % period) - if current_count > limit: - return True + sanitized_path = sanitize_path(path) + key = f"ratelimit:{user_id}:{sanitized_path}:{window_start}" - except Exception as e: - logger.exception(f"Error checking rate limit for user {user_id} on path {path}: {e}") - raise e + try: + current_count = await client.incr(key) + if current_count == 1: + await client.expire(key, period) - return False + if current_count > limit: + return True + + except Exception as e: + logger.exception(f"Error checking rate limit for user {user_id} on path {path}: {e}") + raise e + + return False + + +rate_limiter = RateLimiter() From a0cad4c6ca55cb7df1fcff913b7e3ebed3ace294 Mon Sep 17 00:00:00 2001 From: Igor Benav Date: Mon, 10 Feb 2025 19:02:45 -0300 Subject: [PATCH 2/2] debug prints removed --- src/app/core/setup.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/app/core/setup.py b/src/app/core/setup.py index 2ddadf4e..071f849a 100644 --- a/src/app/core/setup.py +++ b/src/app/core/setup.py @@ -92,30 +92,23 @@ async def lifespan(app: FastAPI) -> AsyncGenerator: initialization_complete = Event() app.state.initialization_complete = initialization_complete - print("1. Starting lifespan") await set_threadpool_tokens() - print("2. Set threadpool tokens") try: if isinstance(settings, RedisCacheSettings): - print("3. Starting Redis cache initialization") await create_redis_cache_pool() if isinstance(settings, RedisQueueSettings): - print("4. Starting Redis queue initialization") await create_redis_queue_pool() if isinstance(settings, RedisRateLimiterSettings): - print("5. Starting Redis rate limit initialization") await create_redis_rate_limit_pool() - print("6. All initialization complete") initialization_complete.set() yield finally: - print("7. Starting shutdown") if isinstance(settings, RedisCacheSettings): await close_redis_cache_pool()