Skip to content

Commit

Permalink
Switched fully to ruff
Browse files Browse the repository at this point in the history
  • Loading branch information
colin99d committed Aug 21, 2023
1 parent e67f7f2 commit cae4853
Show file tree
Hide file tree
Showing 8 changed files with 80 additions and 79 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ redis = ["redis"]
ignore-init-module-imports = true
line-length=105
select = ["E", "W", "F", "Q", "W", "S", "UP", "I", "PD", "SIM", "PLC", "PLE", "PLR", "PLW", "T20", "PYI"]
target-version = "py37"

[tool.ruff.per-file-ignores]
"tests/*" = ["S101", "S105", "PLR2004", "PLR0913", "S311"]
2 changes: 1 addition & 1 deletion slowapi/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,4 @@ def __init__(self, limit: Limit) -> None:
)
else:
description = str(limit.limit)
super(RateLimitExceeded, self).__init__(status_code=429, detail=description)
super().__init__(status_code=429, detail=description)
47 changes: 23 additions & 24 deletions slowapi/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ class Limiter:
* **key_style**: set to "url" to use the url, "endpoint" to use the view_func
"""

def __init__(
def __init__( # noqa: PLR0913,PLR0915
self,
# app: Starlette = None,
key_func: Callable[..., str],
Expand Down Expand Up @@ -325,8 +325,8 @@ def slowapi_startup(self) -> None:
"""
Starlette startup event handler that links the app with the Limiter instance.
"""
app.state.limiter = self # type: ignore
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) # type: ignore
app.state.limiter = self # type: ignore # noqa: F821
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) # type: ignore # noqa

def get_app_config(self, key: str, default_value: T = None) -> T:
"""
Expand Down Expand Up @@ -363,9 +363,10 @@ def limiter(self) -> RateLimiter:
The backend that keeps track of consumption of endpoints vs limits
"""
if self._storage_dead and self._in_memory_fallback_enabled:
assert (
self._fallback_limiter
), "Fallback limiter is needed when in memory fallback is enabled"
if not self._fallback_limiter:
raise AssertionError(
"Fallback limiter is needed when in memory fallback is enabled"
)
return self._fallback_limiter
else:
return self._limiter
Expand Down Expand Up @@ -407,7 +408,7 @@ def _inject_headers(
if self._retry_after == "http-date"
else str(int(reset_in - time.time()))
)
except:
except: # noqa: E722
if self._in_memory_fallback and not self._storage_dead:
self.logger.warning(
"Rate limit storage unreachable - falling back to"
Expand Down Expand Up @@ -489,7 +490,7 @@ def __evaluate_limits(
if lim.per_method:
limit_scope += ":%s" % request.method

if "request" in inspect.signature(lim.key_func).parameters.keys():
if "request" in inspect.signature(lim.key_func).parameters:
limit_key = lim.key_func(request)
else:
limit_key = lim.key_func()
Expand Down Expand Up @@ -543,7 +544,7 @@ def _determine_retry_time(self, retry_header_value) -> int:

return int(time.time() + retry_after_int)

def _check_request_limit(
def _check_request_limit( # noqa: PLR0912
self,
request: Request,
endpoint_func: Optional[Callable[..., Any]],
Expand Down Expand Up @@ -595,13 +596,12 @@ def _check_request_limit(
if self._storage_dead and self._fallback_limiter:
if in_middleware and endpoint_func_name in self.__marked_for_limiting:
pass
elif self.__should_check_backend() and self._storage.check():
self.logger.info("Rate limit storage recovered")
self._storage_dead = False
self.__check_backend_count = 0
else:
if self.__should_check_backend() and self._storage.check():
self.logger.info("Rate limit storage recovered")
self._storage_dead = False
self.__check_backend_count = 0
else:
all_limits = list(itertools.chain(*self._in_memory_fallback))
all_limits = list(itertools.chain(*self._in_memory_fallback))
if not all_limits:
route_limits: List[Limit] = limits + dynamic_limits
all_limits = (
Expand Down Expand Up @@ -634,13 +634,12 @@ def _check_request_limit(
)
self._storage_dead = True
self._check_request_limit(request, endpoint_func, in_middleware)
elif self._swallow_errors:
self.logger.exception("Failed to rate limit. Swallowing error")
else:
if self._swallow_errors:
self.logger.exception("Failed to rate limit. Swallowing error")
else:
raise
raise

def __limit_decorator(
def __limit_decorator( # noqa: PLR0913,PLR0915
self,
limit_value: StrOrCallableStr,
key_func: Optional[Callable[..., str]] = None,
Expand Down Expand Up @@ -701,7 +700,7 @@ def decorator(func: Callable[..., Response]):

sig = inspect.signature(func)
for idx, parameter in enumerate(sig.parameters.values()):
if parameter.name == "request" or parameter.name == "websocket":
if parameter.name in ("request", "websocket"):
break
else:
raise Exception(
Expand Down Expand Up @@ -774,7 +773,7 @@ def sync_wrapper(*args: Any, **kwargs: Any) -> Response:

return decorator

def limit(
def limit( # noqa: PLR0913
self,
limit_value: StrOrCallableStr,
key_func: Optional[Callable[..., str]] = None,
Expand Down Expand Up @@ -814,7 +813,7 @@ def limit(
override_defaults=override_defaults,
)

def shared_limit(
def shared_limit( # noqa: PLR0913
self,
limit_value: Union[str, Callable[[str], str]],
scope: StrOrCallableStr,
Expand Down Expand Up @@ -859,7 +858,7 @@ def exempt(self, obj):
"""
Decorator to mark a view as exempt from rate limits.
"""
name = "%s.%s" % (obj.__module__, obj.__name__)
name = f"{obj.__module__}.{obj.__name__}"

self._exempt_routes.add(name)

Expand Down
2 changes: 1 addition & 1 deletion slowapi/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from starlette.requests import Request
from starlette.responses import Response
from starlette.routing import BaseRoute, Match
from starlette.types import ASGIApp, Message, Scope, Receive, Send
from starlette.types import ASGIApp, Message, Receive, Scope, Send

from slowapi import Limiter, _rate_limit_exceeded_handler

Expand Down
24 changes: 14 additions & 10 deletions slowapi/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
from starlette.requests import Request


class Limit(object):
class Limit:
"""
simple wrapper to encapsulate limits and their context
"""

def __init__(
def __init__( # noqa: PLR0913
self,
limit: RateLimitItem,
key_func: Callable[..., str],
Expand All @@ -29,7 +29,9 @@ def __init__(
self.methods = methods
self.error_message = error_message
self.exempt_when = exempt_when
self._exempt_when_takes_request = len(inspect.signature(self.exempt_when).parameters) == 1
self._exempt_when_takes_request = (
len(inspect.signature(self.exempt_when).parameters) == 1
)
self.cost = cost
self.override_defaults = override_defaults

Expand All @@ -56,18 +58,18 @@ def scope(self) -> str:
return ""
else:
return (
self.__scope(request.endpoint) # type: ignore
self.__scope(request.endpoint) # noqa: F821 # type: ignore
if callable(self.__scope)
else self.__scope
)


class LimitGroup(object):
class LimitGroup:
"""
represents a group of related limits either from a string or a callable that returns one
"""

def __init__(
def __init__( # noqa: PLR0913
self,
limit_provider: Union[str, Callable[..., str]],
key_function: Callable[..., str],
Expand All @@ -92,10 +94,12 @@ def __init__(

def __iter__(self) -> Iterator[Limit]:
if callable(self.__limit_provider):
if "key" in inspect.signature(self.__limit_provider).parameters.keys():
assert (
"request" in inspect.signature(self.key_function).parameters.keys()
), f"Limit provider function {self.key_function.__name__} needs a `request` argument"
if "key" in inspect.signature(self.__limit_provider).parameters:
if "request" not in inspect.signature(self.key_function).parameters:
func_name = self.key_function.__name__
raise AssertionError(
f"Limit provider function {func_name} needs a `request` argument"
)
if self.request is None:
raise Exception("`request` object can't be None")
limit_raw = self.__limit_provider(self.key_function(self.request))
Expand Down
4 changes: 2 additions & 2 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import asyncio
import logging
from unittest import mock # type: ignore

import pytest
from fastapi import FastAPI
from mock import mock # type: ignore
from starlette.applications import Starlette
from starlette.requests import Request

from slowapi.errors import RateLimitExceeded
from slowapi.extension import Limiter, _rate_limit_exceeded_handler
from slowapi.middleware import SlowAPIMiddleware, SlowAPIASGIMiddleware
from slowapi.middleware import SlowAPIASGIMiddleware, SlowAPIMiddleware
from slowapi.util import get_remote_address


Expand Down
6 changes: 3 additions & 3 deletions tests/test_fastapi_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def test_multiple_decorators(self, build_fastapi_app):
async def t1(request: Request):
return PlainTextResponse("test")

with hiro.Timeline().freeze() as timeline:
with hiro.Timeline().freeze():
cli = TestClient(app)
for i in range(0, 100):
response = cli.get("/t1", headers={"X_FORWARDED_FOR": "127.0.0.2"})
Expand All @@ -105,7 +105,7 @@ def test_multiple_decorators_not_response(self, build_fastapi_app):
async def t1(request: Request, response: Response):
return {"key": "value"}

with hiro.Timeline().freeze() as timeline:
with hiro.Timeline().freeze():
cli = TestClient(app)
for i in range(0, 100):
response = cli.get("/t1", headers={"X_FORWARDED_FOR": "127.0.0.2"})
Expand All @@ -130,7 +130,7 @@ def test_multiple_decorators_not_response_with_headers(self, build_fastapi_app):
async def t1(request: Request, response: Response):
return {"key": "value"}

with hiro.Timeline().freeze() as timeline:
with hiro.Timeline().freeze():
cli = TestClient(app)
for i in range(0, 100):
response = cli.get("/t1", headers={"X_FORWARDED_FOR": "127.0.0.2"})
Expand Down
73 changes: 35 additions & 38 deletions tests/test_starlette_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ async def t1(request: Request):

app.add_route("/t1", t1)

with hiro.Timeline().freeze() as timeline:
with hiro.Timeline().freeze():
cli = TestClient(app)
for i in range(0, 10):
response = cli.get("/t1", headers={"X_FORWARDED_FOR": "127.0.0.2"})
Expand All @@ -100,7 +100,7 @@ async def t1(request: Request):

app.add_route("/t1", t1)

with hiro.Timeline().freeze() as timeline:
with hiro.Timeline().freeze():
cli = TestClient(app)
for i in range(0, 10):
response = cli.get("/t1", headers={"X_FORWARDED_FOR": "127.0.0.2"})
Expand Down Expand Up @@ -130,23 +130,22 @@ def t1(request: Request):
def t2(request: Request):
return PlainTextResponse("test")

with hiro.Timeline().freeze():
with TestClient(app) as cli:
resp = cli.get("/t1")
assert resp.headers.get("X-RateLimit-Limit") == "10"
assert resp.headers.get("X-RateLimit-Remaining") == "9"
assert resp.headers.get("X-RateLimit-Reset") == str(
int(time.time() + 61)
)
assert resp.headers.get("Retry-After") == str(60)
resp = cli.get("/t2")
assert resp.headers.get("X-RateLimit-Limit") == "2"
assert resp.headers.get("X-RateLimit-Remaining") == "1"
assert resp.headers.get("X-RateLimit-Reset") == str(
int(time.time() + 2)
)
with hiro.Timeline().freeze(), TestClient(app) as cli:
resp = cli.get("/t1")
assert resp.headers.get("X-RateLimit-Limit") == "10"
assert resp.headers.get("X-RateLimit-Remaining") == "9"
assert resp.headers.get("X-RateLimit-Reset") == str(
int(time.time() + 61)
)
assert resp.headers.get("Retry-After") == str(60)
resp = cli.get("/t2")
assert resp.headers.get("X-RateLimit-Limit") == "2"
assert resp.headers.get("X-RateLimit-Remaining") == "1"
assert resp.headers.get("X-RateLimit-Reset") == str(
int(time.time() + 2)
)

assert resp.headers.get("Retry-After") == str(1)
assert resp.headers.get("Retry-After") == str(1)

def test_headers_breach(self, build_starlette_app):
app, limiter = build_starlette_app(
Expand All @@ -158,18 +157,17 @@ def test_headers_breach(self, build_starlette_app):
def t(request: Request):
return PlainTextResponse("test")

with hiro.Timeline().freeze() as timeline:
with TestClient(app) as cli:
for i in range(11):
resp = cli.get("/t1")
timeline.forward(1)
with hiro.Timeline().freeze() as timeline, TestClient(app) as cli:
for i in range(11):
resp = cli.get("/t1")
timeline.forward(1)

assert resp.headers.get("X-RateLimit-Limit") == "10"
assert resp.headers.get("X-RateLimit-Remaining") == "0"
assert resp.headers.get("X-RateLimit-Reset") == str(
int(time.time() + 50)
)
assert resp.headers.get("Retry-After") == str(int(50))
assert resp.headers.get("X-RateLimit-Limit") == "10"
assert resp.headers.get("X-RateLimit-Remaining") == "0"
assert resp.headers.get("X-RateLimit-Reset") == str(
int(time.time() + 50)
)
assert resp.headers.get("Retry-After") == str(50)

def test_retry_after(self, build_starlette_app):
# FIXME: this test is not actually running!
Expand All @@ -183,14 +181,13 @@ def test_retry_after(self, build_starlette_app):
def t(request: Request):
return PlainTextResponse("test")

with hiro.Timeline().freeze() as timeline:
with TestClient(app) as cli:
resp = cli.get("/t1")
retry_after = int(resp.headers.get("Retry-After"))
assert retry_after > 0
timeline.forward(retry_after)
resp = cli.get("/t1")
assert resp.status_code == 200
with hiro.Timeline().freeze() as timeline, TestClient(app) as cli:
resp = cli.get("/t1")
retry_after = int(resp.headers.get("Retry-After"))
assert retry_after > 0
timeline.forward(retry_after)
resp = cli.get("/t1")
assert resp.status_code == 200

def test_exempt_decorator(self, build_starlette_app):
app, limiter = build_starlette_app(
Expand Down Expand Up @@ -245,7 +242,7 @@ async def t1(request: Request):

app.add_route("/t1", t1)

with hiro.Timeline().freeze() as timeline:
with hiro.Timeline().freeze():
cli = TestClient(app)
for i in range(0, 10):
response = cli.get("/t1", headers={"X_FORWARDED_FOR": "127.0.0.2"})
Expand Down

0 comments on commit cae4853

Please sign in to comment.