diff --git a/fastapi_users/authentication/authenticator.py b/fastapi_users/authentication/authenticator.py index 7fab4b78..772e4260 100644 --- a/fastapi_users/authentication/authenticator.py +++ b/fastapi_users/authentication/authenticator.py @@ -8,6 +8,7 @@ from fastapi_users import models from fastapi_users.authentication.backend import AuthenticationBackend from fastapi_users.authentication.strategy import Strategy +from fastapi_users.authentication.token import UserTokenData from fastapi_users.manager import BaseUserManager, UserManagerDependency from fastapi_users.types import DependencyCallable @@ -62,6 +63,7 @@ def current_user_token( active: bool = False, verified: bool = False, superuser: bool = False, + fresh: bool = False, get_enabled_backends: Optional[EnabledBackendsDependency] = None, ): """ @@ -89,14 +91,16 @@ def current_user_token( @with_signature(signature) async def current_user_token_dependency(*args, **kwargs): - return await self._authenticate( + token_data, token = await self._authenticate( *args, optional=optional, active=active, verified=verified, superuser=superuser, + fresh=fresh, **kwargs, ) + return token_data.user, token return current_user_token_dependency @@ -106,6 +110,7 @@ def current_user( active: bool = False, verified: bool = False, superuser: bool = False, + fresh: bool = False, get_enabled_backends: Optional[EnabledBackendsDependency] = None, ): """ @@ -133,18 +138,68 @@ def current_user( @with_signature(signature) async def current_user_dependency(*args, **kwargs): - user, _ = await self._authenticate( + token_data, _ = await self._authenticate( *args, optional=optional, active=active, verified=verified, superuser=superuser, + fresh=fresh, **kwargs, ) - return user + if token_data: + return token_data.user + return None return current_user_dependency + def current_token( + self, + optional: bool = False, + active: bool = False, + verified: bool = False, + superuser: bool = False, + fresh: bool = False, + get_enabled_backends: Optional[EnabledBackendsDependency] = None, + ): + """ + Return a dependency callable to retrieve the full token data for the currently authenticated user. + + :param optional: If `True`, `None` is returned if there is no authenticated user + or if it doesn't pass the other requirements. + Otherwise, throw `401 Unauthorized`. Defaults to `False`. + Otherwise, an exception is raised. Defaults to `False`. + :param active: If `True`, throw `401 Unauthorized` if + the authenticated user is inactive. Defaults to `False`. + :param verified: If `True`, throw `401 Unauthorized` if + the authenticated user is not verified. Defaults to `False`. + :param superuser: If `True`, throw `403 Forbidden` if + the authenticated user is not a superuser. Defaults to `False`. + :param get_enabled_backends: Optional dependency callable returning + a list of enabled authentication backends. + Useful if you want to dynamically enable some authentication backends + based on external logic, like a configuration in database. + By default, all specified authentication backends are enabled. + Please not however that every backends will appear in the OpenAPI documentation, + as FastAPI resolves it statically. + """ + signature = self._get_dependency_signature(get_enabled_backends) + + @with_signature(signature) + async def current_token_dependency(*args, **kwargs): + token_data, _ = await self._authenticate( + *args, + optional=optional, + active=active, + verified=verified, + superuser=superuser, + fresh=fresh, + **kwargs, + ) + return token_data + + return current_token_dependency + async def _authenticate( self, *args, @@ -153,9 +208,10 @@ async def _authenticate( active: bool = False, verified: bool = False, superuser: bool = False, + fresh: bool = False, **kwargs, - ) -> Tuple[Optional[models.UP], Optional[str]]: - user: Optional[models.UP] = None + ) -> Tuple[Optional[UserTokenData[models.UP, models.ID]], Optional[str]]: + token_data: Optional[UserTokenData[models.UP, models.ID]] = None token: Optional[str] = None enabled_backends: Sequence[AuthenticationBackend] = kwargs.get( "enabled_backends", self.backends @@ -163,27 +219,30 @@ async def _authenticate( for backend in self.backends: if backend in enabled_backends: token = kwargs[name_to_variable_name(backend.name)] - strategy: Strategy[models.UP, models.ID] = kwargs[ + strategy: Strategy = kwargs[ name_to_strategy_variable_name(backend.name) ] if token is not None: - user = await strategy.read_token(token, user_manager) - if user: + token_data = await strategy.read_token(token, user_manager) + if token_data: break status_code = status.HTTP_401_UNAUTHORIZED - if user: - status_code = status.HTTP_403_FORBIDDEN - if active and not user.is_active: - status_code = status.HTTP_401_UNAUTHORIZED - user = None - elif ( - verified and not user.is_verified or superuser and not user.is_superuser - ): - user = None - if not user and not optional: + if token_data: + if token_data.user: + status_code = status.HTTP_403_FORBIDDEN + if active and not token_data.user.is_active: + status_code = status.HTTP_401_UNAUTHORIZED + token_data = None + elif ( + (verified and not token_data.user.is_verified) + or (superuser and not token_data.user.is_superuser) + or (fresh and not token_data.fresh) + ): + token_data = None + if not token_data and not optional: raise HTTPException(status_code=status_code) - return user, token + return token_data, token def _get_dependency_signature( self, get_enabled_backends: Optional[EnabledBackendsDependency] = None diff --git a/fastapi_users/authentication/backend.py b/fastapi_users/authentication/backend.py index b02ca98d..331b613f 100644 --- a/fastapi_users/authentication/backend.py +++ b/fastapi_users/authentication/backend.py @@ -1,4 +1,5 @@ -from typing import Any, Generic +from datetime import datetime +from typing import Any, Generic, Optional, Set from fastapi import Response @@ -7,14 +8,19 @@ Strategy, StrategyDestroyNotSupportedError, ) +from fastapi_users.authentication.token import UserTokenData from fastapi_users.authentication.transport import ( + LoginT, + LogoutT, Transport, TransportLogoutNotSupportedError, + TransportTokenResponse, ) +from fastapi_users.scopes import SystemScope from fastapi_users.types import DependencyCallable -class AuthenticationBackend(Generic[models.UP, models.ID]): +class AuthenticationBackend(Generic[LoginT, LogoutT]): """ Combination of an authentication transport and strategy. @@ -27,34 +33,67 @@ class AuthenticationBackend(Generic[models.UP, models.ID]): """ name: str - transport: Transport + transport: Transport[LoginT, LogoutT] def __init__( self, name: str, - transport: Transport, - get_strategy: DependencyCallable[Strategy[models.UP, models.ID]], + transport: Transport[LoginT, LogoutT], + get_strategy: DependencyCallable[Strategy], + access_token_lifetime_seconds: Optional[int] = 3600, + refresh_token_enabled: bool = False, + refresh_token_lifetime_seconds: Optional[int] = 86400, ): self.name = name self.transport = transport self.get_strategy = get_strategy + self.access_token_lifetime_seconds = access_token_lifetime_seconds + self.refresh_token_enabled = refresh_token_enabled + self.refresh_token_lifetime_seconds = refresh_token_lifetime_seconds async def login( self, - strategy: Strategy[models.UP, models.ID], - user: models.UP, + strategy: Strategy, + user: models.UserProtocol[Any], response: Response, - ) -> Any: - token = await strategy.write_token(user) - return await self.transport.get_login_response(token, response) + last_authenticated: Optional[datetime] = None, + ) -> Optional[LoginT]: + scopes: Set[str] = set() + if user.is_active: + scopes.add(SystemScope.USER) + if user.is_verified: + scopes.add(SystemScope.VERIFIED) + if user.is_superuser: + scopes.add(SystemScope.SUPERUSER) + + access_token_data = UserTokenData.issue_now( + user, + self.access_token_lifetime_seconds, + last_authenticated, + scopes=scopes, + ) + token_response = TransportTokenResponse( + access_token=await strategy.write_token(access_token_data) + ) + if self.refresh_token_enabled: + refresh_token_data = UserTokenData.issue_now( + user, + self.refresh_token_lifetime_seconds, + last_authenticated, + scopes={SystemScope.REFRESH}, + ) + token_response.refresh_token = await strategy.write_token( + refresh_token_data + ) + return await self.transport.get_login_response(token_response, response) async def logout( self, - strategy: Strategy[models.UP, models.ID], - user: models.UP, + strategy: Strategy, + user: models.UserProtocol[Any], token: str, response: Response, - ) -> Any: + ) -> Optional[LogoutT]: try: await strategy.destroy_token(token, user) except StrategyDestroyNotSupportedError: diff --git a/fastapi_users/authentication/strategy/base.py b/fastapi_users/authentication/strategy/base.py index ce60db13..e6bb0557 100644 --- a/fastapi_users/authentication/strategy/base.py +++ b/fastapi_users/authentication/strategy/base.py @@ -1,5 +1,5 @@ import sys -from typing import Generic, Optional +from typing import Any, Dict, Generic, Optional if sys.version_info < (3, 8): from typing_extensions import Protocol # pragma: no cover @@ -7,6 +7,7 @@ from typing import Protocol # pragma: no cover from fastapi_users import models +from fastapi_users.authentication.token import UserTokenData from fastapi_users.manager import BaseUserManager @@ -14,14 +15,23 @@ class StrategyDestroyNotSupportedError(Exception): pass -class Strategy(Protocol, Generic[models.UP, models.ID]): +class Strategy(Protocol): async def read_token( - self, token: Optional[str], user_manager: BaseUserManager[models.UP, models.ID] - ) -> Optional[models.UP]: + self, + token: Optional[str], + user_manager: BaseUserManager[models.UP, models.ID], + ) -> Optional[UserTokenData[models.UP, models.ID]]: ... # pragma: no cover - async def write_token(self, user: models.UP) -> str: + async def write_token( + self, + token_data: UserTokenData[models.UserProtocol[Any], Any], + ) -> str: ... # pragma: no cover - async def destroy_token(self, token: str, user: models.UP) -> None: + async def destroy_token( + self, + token: str, + user: models.UserProtocol[Any], + ) -> None: ... # pragma: no cover diff --git a/fastapi_users/authentication/strategy/db/models.py b/fastapi_users/authentication/strategy/db/models.py index 6c3b58be..8456e0ae 100644 --- a/fastapi_users/authentication/strategy/db/models.py +++ b/fastapi_users/authentication/strategy/db/models.py @@ -1,6 +1,6 @@ import sys from datetime import datetime -from typing import TypeVar +from typing import Optional, TypeVar if sys.version_info < (3, 8): from typing_extensions import Protocol # pragma: no cover @@ -16,6 +16,9 @@ class AccessTokenProtocol(Protocol[models.ID]): token: str user_id: models.ID created_at: datetime + expires_at: Optional[datetime] + last_authenticated: datetime + scopes: str def __init__(self, *args, **kwargs) -> None: ... # pragma: no cover diff --git a/fastapi_users/authentication/strategy/db/strategy.py b/fastapi_users/authentication/strategy/db/strategy.py index d7c3c7a3..946b99de 100644 --- a/fastapi_users/authentication/strategy/db/strategy.py +++ b/fastapi_users/authentication/strategy/db/strategy.py @@ -6,50 +6,69 @@ from fastapi_users.authentication.strategy.base import Strategy from fastapi_users.authentication.strategy.db.adapter import AccessTokenDatabase from fastapi_users.authentication.strategy.db.models import AP +from fastapi_users.authentication.token import TokenData, UserTokenData from fastapi_users.manager import BaseUserManager -class DatabaseStrategy( - Strategy[models.UP, models.ID], Generic[models.UP, models.ID, AP] -): - def __init__( - self, database: AccessTokenDatabase[AP], lifetime_seconds: Optional[int] = None - ): +class DatabaseStrategy(Strategy, Generic[AP]): + def __init__(self, database: AccessTokenDatabase[AP]): self.database = database - self.lifetime_seconds = lifetime_seconds async def read_token( - self, token: Optional[str], user_manager: BaseUserManager[models.UP, models.ID] - ) -> Optional[models.UP]: + self, + token: Optional[str], + user_manager: BaseUserManager[models.UP, models.ID], + ) -> Optional[UserTokenData[models.UP, models.ID]]: + if token is None: return None - max_age = None - if self.lifetime_seconds: - max_age = datetime.now(timezone.utc) - timedelta( - seconds=self.lifetime_seconds - ) - - access_token = await self.database.get_by_token(token, max_age) + access_token = await self.database.get_by_token(token) if access_token is None: return None + token_data = TokenData( + user_id=access_token.user_id, + created_at=access_token.created_at, + expires_at=access_token.expires_at, + last_authenticated=access_token.last_authenticated, + scopes=( + set(access_token.scopes.split(" ")) if access_token.scopes else set() + ), + ) + + if token_data.expired: + return None + try: - parsed_id = user_manager.parse_id(access_token.user_id) - return await user_manager.get(parsed_id) + return await token_data.lookup_user(user_manager) except (exceptions.UserNotExists, exceptions.InvalidID): return None - async def write_token(self, user: models.UP) -> str: - access_token_dict = self._create_access_token_dict(user) + async def write_token( + self, + token_data: UserTokenData[models.UserProtocol[Any], Any], + ) -> str: + access_token_dict = self._create_access_token_dict(token_data) access_token = await self.database.create(access_token_dict) return access_token.token - async def destroy_token(self, token: str, user: models.UP) -> None: + async def destroy_token( + self, + token: str, + user: models.UserProtocol[Any], + ) -> None: access_token = await self.database.get_by_token(token) if access_token is not None: await self.database.delete(access_token) - def _create_access_token_dict(self, user: models.UP) -> Dict[str, Any]: + def _create_access_token_dict( + self, token_data: UserTokenData[models.UP, models.ID] + ) -> Dict[str, Any]: token = secrets.token_urlsafe() - return {"token": token, "user_id": user.id} + return { + "token": token, + "user_id": token_data.user.id, + "scopes": token_data.scope, + **token_data.dict(exclude={"user", "scopes"}), + } diff --git a/fastapi_users/authentication/strategy/jwt.py b/fastapi_users/authentication/strategy/jwt.py index 39866062..271b91f1 100644 --- a/fastapi_users/authentication/strategy/jwt.py +++ b/fastapi_users/authentication/strategy/jwt.py @@ -1,4 +1,5 @@ -from typing import Generic, List, Optional +from datetime import datetime, timezone +from typing import Any, List, Optional import jwt @@ -7,21 +8,20 @@ Strategy, StrategyDestroyNotSupportedError, ) -from fastapi_users.jwt import SecretType, decode_jwt, generate_jwt +from fastapi_users.authentication.token import UserTokenData +from fastapi_users.jwt import SecretType, decode_jwt, generate_jwt # type: ignore from fastapi_users.manager import BaseUserManager -class JWTStrategy(Strategy[models.UP, models.ID], Generic[models.UP, models.ID]): +class JWTStrategy(Strategy): def __init__( self, secret: SecretType, - lifetime_seconds: Optional[int], token_audience: List[str] = ["fastapi-users:auth"], algorithm: str = "HS256", public_key: Optional[SecretType] = None, ): self.secret = secret - self.lifetime_seconds = lifetime_seconds self.token_audience = token_audience self.algorithm = algorithm self.public_key = public_key @@ -35,8 +35,10 @@ def decode_key(self) -> SecretType: return self.public_key or self.secret async def read_token( - self, token: Optional[str], user_manager: BaseUserManager[models.UP, models.ID] - ) -> Optional[models.UP]: + self, + token: Optional[str], + user_manager: BaseUserManager[models.UP, Any], + ) -> Optional[UserTokenData[models.UP, models.ID]]: if token is None: return None @@ -44,25 +46,56 @@ async def read_token( data = decode_jwt( token, self.decode_key, self.token_audience, algorithms=[self.algorithm] ) - user_id = data.get("user_id") - if user_id is None: - return None except jwt.PyJWTError: return None + if any(x not in data for x in ["sub", "iat", "auth_time"]): + return None + + user_id = data["sub"] try: parsed_id = user_manager.parse_id(user_id) - return await user_manager.get(parsed_id) + user = await user_manager.get(parsed_id) except (exceptions.UserNotExists, exceptions.InvalidID): return None - async def write_token(self, user: models.UP) -> str: - data = {"user_id": str(user.id), "aud": self.token_audience} - return generate_jwt( - data, self.encode_key, self.lifetime_seconds, algorithm=self.algorithm + if "exp" in data: + expires_at = datetime.fromtimestamp(data["exp"], tz=timezone.utc) + else: + expires_at = None + + scope = data["scope"] + + return UserTokenData( + user=user, + created_at=datetime.fromtimestamp(data["iat"], tz=timezone.utc), + expires_at=expires_at, + last_authenticated=datetime.fromtimestamp( + data["auth_time"], tz=timezone.utc + ), + scopes=set(scope.split(" ")) if scope else set(), ) - async def destroy_token(self, token: str, user: models.UP) -> None: + async def write_token( + self, + token_data: UserTokenData[models.UserProtocol[Any], Any], + ) -> str: + data = { + "sub": str(token_data.user.id), + "aud": self.token_audience, + "iat": int(token_data.created_at.timestamp()), + "scope": token_data.scope, + "auth_time": int(token_data.last_authenticated.timestamp()), + } + if token_data.expires_at: + data["exp"] = int(token_data.expires_at.timestamp()) + return generate_jwt(data, self.encode_key, algorithm=self.algorithm) + + async def destroy_token( + self, + token: str, + user: models.UserProtocol[Any], + ) -> None: raise StrategyDestroyNotSupportedError( "A JWT can't be invalidated: it's valid until it expires." ) diff --git a/fastapi_users/authentication/strategy/redis.py b/fastapi_users/authentication/strategy/redis.py index 58082867..21aecfd2 100644 --- a/fastapi_users/authentication/strategy/redis.py +++ b/fastapi_users/authentication/strategy/redis.py @@ -1,47 +1,71 @@ import secrets -from typing import Generic, Optional +from typing import Any, Optional +import pydantic import redis.asyncio from fastapi_users import exceptions, models from fastapi_users.authentication.strategy.base import Strategy +from fastapi_users.authentication.token import TokenData, UserTokenData from fastapi_users.manager import BaseUserManager -class RedisStrategy(Strategy[models.UP, models.ID], Generic[models.UP, models.ID]): +class RedisStrategy(Strategy): def __init__( self, redis: redis.asyncio.Redis, - lifetime_seconds: Optional[int] = None, *, key_prefix: str = "fastapi_users_token:", ): self.redis = redis - self.lifetime_seconds = lifetime_seconds self.key_prefix = key_prefix async def read_token( - self, token: Optional[str], user_manager: BaseUserManager[models.UP, models.ID] - ) -> Optional[models.UP]: + self, + token: Optional[str], + user_manager: BaseUserManager[models.UP, models.ID], + ) -> Optional[UserTokenData[models.UP, models.ID]]: + if token is None: return None - user_id = await self.redis.get(f"{self.key_prefix}{token}") - if user_id is None: + token_value = await self.redis.get(f"{self.key_prefix}{token}") + if token_value is None: return None try: - parsed_id = user_manager.parse_id(user_id) - return await user_manager.get(parsed_id) + token_data = TokenData.parse_raw(token_value) + except pydantic.ValidationError: + return None + + if token_data is None: + return None + + try: + return await token_data.lookup_user(user_manager) except (exceptions.UserNotExists, exceptions.InvalidID): return None - async def write_token(self, user: models.UP) -> str: + async def write_token( + self, + token_data: UserTokenData[models.UserProtocol[Any], Any], + ) -> str: token = secrets.token_urlsafe() + expiry = ( + None + if not token_data.time_to_expiry + else int(token_data.time_to_expiry.total_seconds()) + ) await self.redis.set( - f"{self.key_prefix}{token}", str(user.id), ex=self.lifetime_seconds + f"{self.key_prefix}{token}", + token_data.json(), + ex=expiry, ) return token - async def destroy_token(self, token: str, user: models.UP) -> None: + async def destroy_token( + self, + token: str, + user: models.UserProtocol[Any], + ) -> None: await self.redis.delete(f"{self.key_prefix}{token}") diff --git a/fastapi_users/authentication/token.py b/fastapi_users/authentication/token.py new file mode 100644 index 00000000..1564c377 --- /dev/null +++ b/fastapi_users/authentication/token.py @@ -0,0 +1,99 @@ +from __future__ import annotations + +import re +from datetime import datetime, timedelta, timezone +from typing import Any, Generic, Iterable, Optional, Set + +from pydantic import BaseModel, Field, validator + +from fastapi_users import models +from fastapi_users.manager import BaseUserManager +from fastapi_users.scopes import Scope, SystemScope + + +class TokenData(BaseModel): + user_id: Any + created_at: datetime + expires_at: Optional[datetime] + last_authenticated: datetime + scopes: Set[Scope] + + @validator("scopes") + def _validate_scopes(cls, v: Set[Scope]) -> Set[Scope]: + def _validate_scope(scope: Scope) -> Scope: + if isinstance(scope, SystemScope): + return scope + # As we are going to use these for OAuth 2.0 tokens, + # all scopes should be valid OAuth 2.0 scopes - see + # https://www.rfc-editor.org/rfc/rfc6749#section-3.3 + assert re.match(r"^[!#-\[\]-~]+$", scope) + try: + return SystemScope(scope) + except ValueError: + return scope + + return set(_validate_scope(x) for x in v) + + @property + def scope(self) -> str: + return " ".join(str(x) for x in self.scopes) + + @property + def fresh(self) -> bool: + return self.created_at == self.last_authenticated + + @property + def expired(self) -> bool: + if self.expires_at is None: + return False + return self.expires_at <= datetime.now(timezone.utc) + + @property + def time_to_expiry(self) -> Optional[timedelta]: + if self.expires_at is None: + return None + return self.expires_at - datetime.now(timezone.utc) + + async def lookup_user( + self, user_manager: BaseUserManager[models.UP, models.ID] + ) -> UserTokenData[models.UP, models.ID]: + user_id = user_manager.parse_id(self.user_id) + return UserTokenData( + user_id=user_id, + **self.dict(exclude={"user_id"}), + user=await user_manager.get(user_id), + ) + + +class UserTokenData(TokenData, Generic[models.UP, models.ID]): + user: models.UP = Field(..., exclude=True) + + @classmethod + def issue_now( + cls, + user: models.UserProtocol[models.ID], + lifetime_seconds: Optional[int] = None, + last_authenticated: Optional[datetime] = None, + scopes: Optional[Iterable[Scope]] = None, + ) -> UserTokenData[models.UP, models.ID]: + + scopes = scopes or set() + + now = datetime.now(timezone.utc) + + if lifetime_seconds is None: + expires_at = None + else: + expires_at = now + timedelta(seconds=lifetime_seconds) + + return cls( + user_id=user.id, + created_at=now, + expires_at=expires_at, + last_authenticated=last_authenticated or now, + scopes=set(scopes), + user=user, + ) + + class Config: + arbitrary_types_allowed = True diff --git a/fastapi_users/authentication/transport/__init__.py b/fastapi_users/authentication/transport/__init__.py index b9eb5579..668aceb2 100644 --- a/fastapi_users/authentication/transport/__init__.py +++ b/fastapi_users/authentication/transport/__init__.py @@ -1,6 +1,9 @@ from fastapi_users.authentication.transport.base import ( + LoginT, + LogoutT, Transport, TransportLogoutNotSupportedError, + TransportTokenResponse, ) from fastapi_users.authentication.transport.bearer import BearerTransport from fastapi_users.authentication.transport.cookie import CookieTransport @@ -9,5 +12,8 @@ "BearerTransport", "CookieTransport", "Transport", + "TransportTokenResponse", "TransportLogoutNotSupportedError", + "LoginT", + "LogoutT", ] diff --git a/fastapi_users/authentication/transport/base.py b/fastapi_users/authentication/transport/base.py index d54c3a5a..3de855ae 100644 --- a/fastapi_users/authentication/transport/base.py +++ b/fastapi_users/authentication/transport/base.py @@ -1,5 +1,7 @@ import sys -from typing import Any +from typing import Optional, Type, TypeVar + +from pydantic import BaseModel if sys.version_info < (3, 8): from typing_extensions import Protocol # pragma: no cover @@ -16,13 +18,26 @@ class TransportLogoutNotSupportedError(Exception): pass -class Transport(Protocol): +class TransportTokenResponse(BaseModel): + access_token: str + refresh_token: Optional[str] = None + + +LoginT = TypeVar("LoginT") +LogoutT = TypeVar("LogoutT") + + +class Transport(Protocol[LoginT, LogoutT]): + login_response_model: Optional[Type[LoginT]] = None + logout_response_model: Optional[Type[LogoutT]] = None scheme: SecurityBase - async def get_login_response(self, token: str, response: Response) -> Any: + async def get_login_response( + self, token: TransportTokenResponse, response: Response + ) -> LoginT: ... # pragma: no cover - async def get_logout_response(self, response: Response) -> Any: + async def get_logout_response(self, response: Response) -> LogoutT: ... # pragma: no cover @staticmethod diff --git a/fastapi_users/authentication/transport/bearer.py b/fastapi_users/authentication/transport/bearer.py index 924fe9f8..b4258935 100644 --- a/fastapi_users/authentication/transport/bearer.py +++ b/fastapi_users/authentication/transport/bearer.py @@ -1,5 +1,3 @@ -from typing import Any - from fastapi import Response, status from fastapi.security import OAuth2PasswordBearer from pydantic import BaseModel @@ -7,25 +5,29 @@ from fastapi_users.authentication.transport.base import ( Transport, TransportLogoutNotSupportedError, + TransportTokenResponse, ) from fastapi_users.openapi import OpenAPIResponseType -class BearerResponse(BaseModel): - access_token: str +class BearerResponse(TransportTokenResponse): token_type: str -class BearerTransport(Transport): +class BearerTransport(Transport[BearerResponse, None]): + login_response_model = BearerResponse + logout_response_model = None scheme: OAuth2PasswordBearer def __init__(self, tokenUrl: str): self.scheme = OAuth2PasswordBearer(tokenUrl, auto_error=False) - async def get_login_response(self, token: str, response: Response) -> Any: - return BearerResponse(access_token=token, token_type="bearer") + async def get_login_response( + self, token: TransportTokenResponse, response: Response + ) -> BearerResponse: + return BearerResponse(**token.dict(), token_type="bearer") - async def get_logout_response(self, response: Response) -> Any: + async def get_logout_response(self, response: Response) -> None: raise TransportLogoutNotSupportedError() @staticmethod diff --git a/fastapi_users/authentication/transport/cookie.py b/fastapi_users/authentication/transport/cookie.py index 6fa8e198..9b24a381 100644 --- a/fastapi_users/authentication/transport/cookie.py +++ b/fastapi_users/authentication/transport/cookie.py @@ -1,13 +1,18 @@ -from typing import Any, Optional +from typing import Optional from fastapi import Response, status from fastapi.security import APIKeyCookie -from fastapi_users.authentication.transport.base import Transport +from fastapi_users.authentication.transport.base import ( + Transport, + TransportTokenResponse, +) from fastapi_users.openapi import OpenAPIResponseType -class CookieTransport(Transport): +class CookieTransport(Transport[None, None]): + login_response_model = None + logout_response_model = None scheme: APIKeyCookie def __init__( @@ -29,10 +34,17 @@ def __init__( self.cookie_samesite = cookie_samesite self.scheme = APIKeyCookie(name=self.cookie_name, auto_error=False) - async def get_login_response(self, token: str, response: Response) -> Any: + async def get_login_response( + self, token: TransportTokenResponse, response: Response + ) -> None: + if token.refresh_token: + raise NotImplementedError( + "Refresh tokens not yet supported by cookie transport" + ) + response.set_cookie( self.cookie_name, - token, + token.access_token, max_age=self.cookie_max_age, path=self.cookie_path, domain=self.cookie_domain, @@ -45,7 +57,7 @@ async def get_login_response(self, token: str, response: Response) -> Any: # so that FastAPI can terminate it properly return None - async def get_logout_response(self, response: Response) -> Any: + async def get_logout_response(self, response: Response) -> None: response.set_cookie( self.cookie_name, "", diff --git a/fastapi_users/router/__init__.py b/fastapi_users/router/__init__.py index 31bf68be..219351da 100644 --- a/fastapi_users/router/__init__.py +++ b/fastapi_users/router/__init__.py @@ -1,5 +1,6 @@ from fastapi_users.router.auth import get_auth_router from fastapi_users.router.common import ErrorCode +from fastapi_users.router.refresh import get_refresh_router from fastapi_users.router.register import get_register_router from fastapi_users.router.reset import get_reset_password_router from fastapi_users.router.users import get_users_router @@ -8,6 +9,7 @@ __all__ = [ "ErrorCode", "get_auth_router", + "get_refresh_router", "get_register_router", "get_reset_password_router", "get_users_router", diff --git a/fastapi_users/router/auth.py b/fastapi_users/router/auth.py index cda0fa3e..f309445a 100644 --- a/fastapi_users/router/auth.py +++ b/fastapi_users/router/auth.py @@ -1,4 +1,4 @@ -from typing import Tuple +from typing import Optional, Tuple from fastapi import APIRouter, Depends, HTTPException, Response, status from fastapi.security import OAuth2PasswordRequestForm @@ -11,7 +11,7 @@ def get_auth_router( - backend: AuthenticationBackend, + backend: AuthenticationBackend[models.UP, models.ID], get_user_manager: UserManagerDependency[models.UP, models.ID], authenticator: Authenticator, requires_verification: bool = False, @@ -46,13 +46,15 @@ def get_auth_router( @router.post( "/login", name=f"auth:{backend.name}.login", + response_model=backend.transport.login_response_model, + response_model_exclude_none=True, responses=login_responses, ) - async def login( + async def login( # type: ignore response: Response, credentials: OAuth2PasswordRequestForm = Depends(), user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager), - strategy: Strategy[models.UP, models.ID] = Depends(backend.get_strategy), + strategy: Strategy = Depends(backend.get_strategy), ): user = await user_manager.authenticate(credentials) @@ -78,12 +80,15 @@ async def login( } @router.post( - "/logout", name=f"auth:{backend.name}.logout", responses=logout_responses + "/logout", + name=f"auth:{backend.name}.logout", + response_model=backend.transport.logout_response_model, + responses=logout_responses, ) - async def logout( + async def logout( # type: ignore response: Response, user_token: Tuple[models.UP, str] = Depends(get_current_user_token), - strategy: Strategy[models.UP, models.ID] = Depends(backend.get_strategy), + strategy: Strategy = Depends(backend.get_strategy), ): user, token = user_token return await backend.logout(strategy, user, token, response) diff --git a/fastapi_users/router/oauth.py b/fastapi_users/router/oauth.py index 16e96d00..b5a14265 100644 --- a/fastapi_users/router/oauth.py +++ b/fastapi_users/router/oauth.py @@ -8,6 +8,7 @@ from fastapi_users import models, schemas from fastapi_users.authentication import AuthenticationBackend, Authenticator, Strategy +from fastapi_users.authentication.token import UserTokenData from fastapi_users.exceptions import UserAlreadyExists from fastapi_users.jwt import SecretType, decode_jwt, generate_jwt from fastapi_users.manager import BaseUserManager, UserManagerDependency @@ -77,6 +78,8 @@ async def authorize( "/callback", name=callback_route_name, description="The response varies based on the authentication backend used.", + response_model=backend.transport.login_response_model, + response_model_exclude_none=True, responses={ status.HTTP_400_BAD_REQUEST: { "model": ErrorModel, @@ -104,7 +107,7 @@ async def callback( oauth2_authorize_callback ), user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager), - strategy: Strategy[models.UP, models.ID] = Depends(backend.get_strategy), + strategy: Strategy = Depends(backend.get_strategy), ): token, state = access_token_state account_id, account_email = await oauth_client.get_id_email( diff --git a/fastapi_users/router/refresh.py b/fastapi_users/router/refresh.py new file mode 100644 index 00000000..4bfe9d5c --- /dev/null +++ b/fastapi_users/router/refresh.py @@ -0,0 +1,91 @@ +from fastapi import APIRouter, Depends, Form, HTTPException, Response, status + +from fastapi_users import models +from fastapi_users.authentication import AuthenticationBackend, Strategy +from fastapi_users.manager import BaseUserManager, UserManagerDependency +from fastapi_users.openapi import OpenAPIResponseType +from fastapi_users.router.common import ErrorCode, ErrorModel +from fastapi_users.scopes import SystemScope + + +class OAuth2RefreshTokenForm(object): + def __init__( + self, + grant_type: str = Form( + default="refresh_token", regex="refresh_token", example="refresh_token" + ), + refresh_token: str = Form(...), + scope: str = Form(""), + ): + self.grant_type = grant_type + self.refresh_token = refresh_token + self.scopes = scope.split() + + +def get_refresh_router( + backend: AuthenticationBackend[models.UP, models.ID], + get_user_manager: UserManagerDependency[models.UP, models.ID], +) -> APIRouter: + """Generate a router with login/logout routes for an authentication backend.""" + router = APIRouter() + + login_responses: OpenAPIResponseType = { + status.HTTP_400_BAD_REQUEST: { + "model": ErrorModel, + "content": { + "application/json": { + "examples": { + ErrorCode.LOGIN_BAD_CREDENTIALS: { + "summary": "Bad credentials or the user is inactive.", + "value": {"detail": ErrorCode.LOGIN_BAD_CREDENTIALS}, + }, + ErrorCode.LOGIN_USER_NOT_VERIFIED: { + "summary": "The user is not verified.", + "value": {"detail": ErrorCode.LOGIN_USER_NOT_VERIFIED}, + }, + } + } + }, + }, + **backend.transport.get_openapi_login_responses_success(), + } + + @router.post( + "/refresh", + name=f"auth:{backend.name}.refresh", + responses=login_responses, + response_model=backend.transport.login_response_model, + response_model_exclude_none=True, + ) + async def refresh( # type: ignore + response: Response, + form_data: OAuth2RefreshTokenForm = Depends(), + user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager), + strategy: Strategy = Depends(backend.get_strategy), + ): + if not backend.refresh_token_enabled: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="refresh_tokens_not_allowed", + ) + + token_data = await strategy.read_token(form_data.refresh_token, user_manager) + if token_data: + if ( + token_data.user + and SystemScope.REFRESH in token_data.scopes + and not token_data.expired + ): + return await backend.login( + strategy, + token_data.user, + response, + token_data.last_authenticated, + ) + + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="invalid_grant", + ) + + return router diff --git a/fastapi_users/router/users.py b/fastapi_users/router/users.py index 8e048abf..a976651c 100644 --- a/fastapi_users/router/users.py +++ b/fastapi_users/router/users.py @@ -4,6 +4,7 @@ from fastapi_users import exceptions, models, schemas from fastapi_users.authentication import Authenticator +from fastapi_users.authentication.token import UserTokenData from fastapi_users.manager import BaseUserManager, UserManagerDependency from fastapi_users.router.common import ErrorCode, ErrorModel diff --git a/fastapi_users/scopes.py b/fastapi_users/scopes.py new file mode 100644 index 00000000..01bf5da5 --- /dev/null +++ b/fastapi_users/scopes.py @@ -0,0 +1,16 @@ +from enum import Enum +from typing import Union + + +class SystemScope(str, Enum): + USER = "fastapi-users:user" + SUPERUSER = "fastapi-users:superuser" + VERIFIED = "fastapi-users:verified" + REFRESH = "fastapi-users:refresh" + + def __str__(self) -> str: + return self.value + + +UserDefinedScope = str +Scope = Union[SystemScope, UserDefinedScope] diff --git a/pyproject.toml b/pyproject.toml index af2d1c6b..17975504 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -97,6 +97,7 @@ dev = [ "asgi_lifespan", "uvicorn", "types-redis", + "pytest-freezegun", ] sqlalchemy = [ "fastapi-users-db-sqlalchemy >=4.0.0", @@ -112,4 +113,4 @@ redis = [ ] [project.urls] -Documentation = "https://fastapi-users.github.io/fastapi-users/" +Documentation = "https://fastapi-users.github.io/fastapi-users/" \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index 9f0c3497..917fe40f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,7 @@ import asyncio import dataclasses import uuid +from datetime import datetime, timedelta, timezone from typing import ( Any, AsyncGenerator, @@ -9,27 +10,31 @@ Generic, List, Optional, + Set, Type, Union, ) from unittest.mock import MagicMock import httpx +import pydantic import pytest from asgi_lifespan import LifespanManager from fastapi import FastAPI, Response from httpx_oauth.oauth2 import OAuth2 -from pydantic import UUID4, SecretStr from pytest_mock import MockerFixture from fastapi_users import exceptions, models, schemas from fastapi_users.authentication import AuthenticationBackend, BearerTransport from fastapi_users.authentication.strategy import Strategy +from fastapi_users.authentication.token import TokenData, UserTokenData +from fastapi_users.authentication.transport.bearer import BearerResponse from fastapi_users.db import BaseUserDatabase from fastapi_users.jwt import SecretType from fastapi_users.manager import BaseUserManager, UUIDIDMixin from fastapi_users.openapi import OpenAPIResponseType from fastapi_users.password import PasswordHelper +from fastapi_users.scopes import Scope, SystemScope password_helper = PasswordHelper() guinevere_password_hash = password_helper.hash("guinevere") @@ -156,7 +161,7 @@ def _async_method_mocker( return _async_method_mocker -@pytest.fixture(params=["SECRET", SecretStr("SECRET")]) +@pytest.fixture(params=["SECRET", pydantic.SecretStr("SECRET")]) def secret(request) -> SecretType: return request.param @@ -323,7 +328,7 @@ def mock_user_db( verified_superuser: UserModel, ) -> BaseUserDatabase[UserModel, IDType]: class MockUserDatabase(BaseUserDatabase[UserModel, IDType]): - async def get(self, id: UUID4) -> Optional[UserModel]: + async def get(self, id: pydantic.UUID4) -> Optional[UserModel]: if id == user.id: return user if id == verified_user.id: @@ -375,7 +380,7 @@ def mock_user_db_oauth( verified_superuser_oauth: UserOAuthModel, ) -> BaseUserDatabase[UserOAuthModel, IDType]: class MockUserDatabase(BaseUserDatabase[UserOAuthModel, IDType]): - async def get(self, id: UUID4) -> Optional[UserOAuthModel]: + async def get(self, id: pydantic.UUID4) -> Optional[UserOAuthModel]: if id == user_oauth.id: return user_oauth if id == verified_user_oauth.id: @@ -523,36 +528,149 @@ def get_openapi_logout_responses_success() -> OpenAPIResponseType: return {} -class MockStrategy(Strategy[UserModel, IDType]): +MockBackend = AuthenticationBackend[BearerResponse, None] + + +class MockStrategy(Strategy): + def __init__(self, token_type: str = "access"): + self.token_type = token_type + async def read_token( - self, token: Optional[str], user_manager: BaseUserManager[UserModel, IDType] - ) -> Optional[UserModel]: - if token is not None: - try: - parsed_id = user_manager.parse_id(token) - return await user_manager.get(parsed_id) - except (exceptions.InvalidID, exceptions.UserNotExists): - return None - return None + self, + token: Optional[str], + user_manager: BaseUserManager[UserModel, IDType], + ) -> Optional[UserTokenData[UserModel, IDType]]: + if token is None: + return None + + try: + token_data = TokenData.parse_raw(token) + except pydantic.ValidationError: + return None + + if token_data is None: + return None + + try: + return await token_data.lookup_user(user_manager) + except (exceptions.UserNotExists, exceptions.InvalidID): + return None - async def write_token(self, user: UserModel) -> str: - return str(user.id) + async def write_token( + self, + token_data: UserTokenData[models.UserProtocol[Any], Any], + ) -> str: + return token_data.json() async def destroy_token(self, token: str, user: UserModel) -> None: return None -def get_mock_authentication(name: str): +def mock_valid_access_token(user: UserModel) -> str: + token_data = UserTokenData.issue_now(user, scopes=[SystemScope.USER]) + return token_data.json() + + +def mock_valid_refresh_token(user: UserModel) -> str: + token_data = UserTokenData.issue_now(user, scopes=[SystemScope.REFRESH]) + return token_data.json() + + +def mock_authorized_headers(user: UserModel) -> Dict[str, str]: + return {"Authorization": f"Bearer {mock_valid_access_token(user)}"} + + +def assert_valid_token_response( + token_response: Dict[str, str], + expected_access_token: TokenData, + expected_refresh_token: Optional[TokenData] = None, +) -> None: + assert isinstance(token_response, dict) + assert set(token_response.keys()).issubset( + {"token_type", "access_token", "refresh_token"} + ) + assert "token_type" in token_response + assert "access_token" in token_response + assert token_response["access_token"] == expected_access_token.json() + if expected_refresh_token: + assert "refresh_token" in token_response + assert token_response["refresh_token"] == expected_refresh_token.json() + else: + assert "refresh_token" not in token_response + + +def mock_token_data( + user_id: Any, + scopes: Set[Scope], + created_at: Optional[datetime] = None, + last_authenticated: Optional[datetime] = None, + lifetime_seconds: Optional[int] = None, +) -> TokenData: + + now = datetime.now(timezone.utc) + + created_at = created_at or now + last_authenticated = last_authenticated or now + + if lifetime_seconds: + expires_at = now + timedelta(seconds=lifetime_seconds) + else: + expires_at = None + + return TokenData( + user_id=user_id, + created_at=created_at, + expires_at=expires_at, + last_authenticated=last_authenticated, + scopes=scopes, + ) + + +def get_mock_authentication( + name: str, + access_token_lifetime_seconds: Optional[int] = 3600, + refresh_token_enabled: bool = False, + refresh_token_lifetime_seconds: Optional[int] = 86400, +) -> MockBackend: return AuthenticationBackend( name=name, transport=MockTransport(tokenUrl="/login"), get_strategy=lambda: MockStrategy(), + access_token_lifetime_seconds=access_token_lifetime_seconds, + refresh_token_enabled=refresh_token_enabled, + refresh_token_lifetime_seconds=refresh_token_lifetime_seconds, ) @pytest.fixture -def mock_authentication(): - return get_mock_authentication(name="mock") +def mock_authentication_factory( + access_token_lifetime_seconds: Optional[int], + refresh_token_enabled: bool, + refresh_token_lifetime_seconds: Optional[int], +) -> Callable[[str], MockBackend]: + def _mock_authentication_factory(name: str): + return get_mock_authentication( + name=name, + access_token_lifetime_seconds=access_token_lifetime_seconds, + refresh_token_enabled=refresh_token_enabled, + refresh_token_lifetime_seconds=refresh_token_lifetime_seconds, + ) + + return _mock_authentication_factory + + +@pytest.fixture +def mock_authentication( + access_token_lifetime_seconds: int, + refresh_token_enabled: bool, + refresh_token_lifetime_seconds: int, +) -> MockBackend: + return get_mock_authentication( + name="mock", + access_token_lifetime_seconds=access_token_lifetime_seconds, + refresh_token_enabled=refresh_token_enabled, + refresh_token_lifetime_seconds=refresh_token_lifetime_seconds, + ) @pytest.fixture @@ -581,3 +699,65 @@ def oauth_client() -> OAuth2: ACCESS_TOKEN_ENDPOINT, name="service1", ) + + +@pytest.fixture(params=[True]) +def token_fresh(request: pytest.FixtureRequest) -> bool: + return request.param # type: ignore + + +@pytest.fixture(params=[False]) +def token_expired(request: pytest.FixtureRequest) -> bool: + return request.param # type: ignore + + +@pytest.fixture(params=[{SystemScope.USER}]) +def scopes(request: pytest.FixtureRequest) -> Set[Scope]: + return request.param # type: ignore + + +@pytest.fixture +def token_data( + user: UserModel, + token_fresh: bool, + token_expired: bool, + scopes: Set[Scope], +) -> UserTokenData[UserModel, IDType]: + + now = datetime.now(timezone.utc) + + if token_expired: + expires_at = now - timedelta(minutes=30) + else: + expires_at = now + timedelta(minutes=30) + + created_at = expires_at - timedelta(hours=1) + + if token_fresh: + last_authenticated = created_at + else: + last_authenticated = created_at - timedelta(days=1) + + return UserTokenData( + user_id=user.id, + created_at=created_at, + expires_at=expires_at, + last_authenticated=last_authenticated, + scopes=scopes, + user=user, + ) + + +@pytest.fixture(params=[False]) +def refresh_token_enabled(request: pytest.FixtureRequest) -> bool: + return getattr(request, "param") + + +@pytest.fixture(params=[3600]) +def access_token_lifetime_seconds(request: pytest.FixtureRequest) -> Optional[int]: + return getattr(request, "param") + + +@pytest.fixture(params=[86400]) +def refresh_token_lifetime_seconds(request: pytest.FixtureRequest) -> Optional[int]: + return getattr(request, "param") diff --git a/tests/test_authentication_authenticator.py b/tests/test_authentication_authenticator.py index d2ee48e2..2a5af3c8 100644 --- a/tests/test_authentication_authenticator.py +++ b/tests/test_authentication_authenticator.py @@ -1,4 +1,4 @@ -from typing import AsyncGenerator, Generic, List, Optional, Sequence +from typing import AsyncGenerator, Generic, List, Optional, Sequence, Tuple import httpx import pytest @@ -9,10 +9,11 @@ from fastapi_users.authentication import AuthenticationBackend, Authenticator from fastapi_users.authentication.authenticator import DuplicateBackendNamesError from fastapi_users.authentication.strategy import Strategy +from fastapi_users.authentication.token import TokenData, UserTokenData from fastapi_users.authentication.transport import Transport from fastapi_users.manager import BaseUserManager from fastapi_users.types import DependencyCallable -from tests.conftest import User, UserModel +from tests.conftest import IDType, User, UserModel class MockSecurityScheme(SecurityBase): @@ -29,19 +30,23 @@ def __init__(self): class NoneStrategy(Strategy): async def read_token( - self, token: Optional[str], user_manager: BaseUserManager[models.UP, models.ID] - ) -> Optional[models.UP]: + self, + token: Optional[str], + user_manager: BaseUserManager[models.UP, models.ID], + ) -> Optional[UserTokenData[models.UP, models.ID]]: return None -class UserStrategy(Strategy, Generic[models.UP]): - def __init__(self, user: models.UP): - self.user = user +class UserStrategy(Strategy, Generic[models.UP, models.ID]): + def __init__(self, token_data: UserTokenData[models.UP, models.ID]): + self.token_data = token_data async def read_token( - self, token: Optional[str], user_manager: BaseUserManager[models.UP, models.ID] - ) -> Optional[models.UP]: - return self.user + self, + token: Optional[str], + user_manager: BaseUserManager[models.UP, models.ID], + ) -> Optional[UserTokenData[models.UP, models.ID]]: + return self.token_data @pytest.fixture @@ -55,20 +60,41 @@ def _get_backend_none(name: str = "none"): @pytest.fixture -def get_backend_user(user: UserModel): +def get_backend_user(token_data: UserTokenData[UserModel, IDType]): def _get_backend_user(name: str = "user"): return AuthenticationBackend( name=name, transport=MockTransport(), - get_strategy=lambda: UserStrategy(user), + get_strategy=lambda: UserStrategy(token_data), ) return _get_backend_user +@pytest.fixture(params=[False]) +def require_active(request: pytest.FixtureRequest) -> bool: + return getattr(request, "param") + + +@pytest.fixture(params=[False]) +def require_superuser(request: pytest.FixtureRequest) -> bool: + return getattr(request, "param") + + +@pytest.fixture(params=[False]) +def require_fresh(request: pytest.FixtureRequest) -> bool: + return getattr(request, "param") + + @pytest.fixture @pytest.mark.asyncio -def get_test_auth_client(get_user_manager, get_test_client): +def get_test_auth_client( + get_user_manager, + get_test_client, + require_active: bool, + require_superuser: bool, + require_fresh: bool, +): async def _get_test_auth_client( backends: List[AuthenticationBackend], get_enabled_backends: Optional[ @@ -81,32 +107,43 @@ async def _get_test_auth_client( @app.get("/test-current-user", response_model=User) def test_current_user( user: UserModel = Depends( - authenticator.current_user(get_enabled_backends=get_enabled_backends) + authenticator.current_user( + active=require_active, + superuser=require_superuser, + fresh=require_fresh, + get_enabled_backends=get_enabled_backends, + ) ), ): return user - @app.get("/test-current-active-user", response_model=User) - def test_current_active_user( - user: UserModel = Depends( - authenticator.current_user( - active=True, get_enabled_backends=get_enabled_backends + @app.get("/test-current-user-token", response_model=User) + def test_current_token( + user_token: Tuple[UserModel, str] = Depends( + authenticator.current_user_token( + active=require_active, + superuser=require_superuser, + fresh=require_fresh, + get_enabled_backends=get_enabled_backends, ) ), ): + user, token = user_token + assert token return user - @app.get("/test-current-superuser", response_model=User) - def test_current_superuser( - user: UserModel = Depends( - authenticator.current_user( - active=True, - superuser=True, + @app.get("/test-current-token", response_model=TokenData) + def test_current_token( + token_data: UserTokenData[UserModel, IDType] = Depends( + authenticator.current_token( + active=require_active, + superuser=require_superuser, + fresh=require_fresh, get_enabled_backends=get_enabled_backends, ) ), ): - return user + return TokenData(**token_data.dict()) async for client in get_test_client(app): yield client @@ -116,26 +153,58 @@ def test_current_superuser( @pytest.mark.authentication @pytest.mark.asyncio -async def test_authenticator(get_test_auth_client, get_backend_none, get_backend_user): +@pytest.mark.parametrize("require_active", [False], indirect=True) +@pytest.mark.parametrize("require_superuser", [False], indirect=True) +@pytest.mark.parametrize( + "path", + [ + "/test-current-user", + "/test-current-user-token", + "/test-current-token", + ], +) +async def test_authenticator( + get_test_auth_client, get_backend_none, get_backend_user, path: str +): async for client in get_test_auth_client([get_backend_none(), get_backend_user()]): - response = await client.get("/test-current-user") + response = await client.get(path) assert response.status_code == status.HTTP_200_OK @pytest.mark.authentication @pytest.mark.asyncio -async def test_authenticator_none(get_test_auth_client, get_backend_none): +@pytest.mark.parametrize("require_active", [False], indirect=True) +@pytest.mark.parametrize("require_superuser", [False], indirect=True) +@pytest.mark.parametrize( + "path", + [ + "/test-current-user", + "/test-current-user-token", + "/test-current-token", + ], +) +async def test_authenticator_none(get_test_auth_client, get_backend_none, path: str): async for client in get_test_auth_client( [get_backend_none(), get_backend_none(name="none-bis")] ): - response = await client.get("/test-current-user") + response = await client.get(path) assert response.status_code == status.HTTP_401_UNAUTHORIZED @pytest.mark.authentication @pytest.mark.asyncio +@pytest.mark.parametrize("require_active", [False], indirect=True) +@pytest.mark.parametrize("require_superuser", [False], indirect=True) +@pytest.mark.parametrize( + "path", + [ + "/test-current-user", + "/test-current-user-token", + "/test-current-token", + ], +) async def test_authenticator_none_enabled( - get_test_auth_client, get_backend_none, get_backend_user + get_test_auth_client, get_backend_none, get_backend_user, path: str ): backend_none = get_backend_none() backend_user = get_backend_user() @@ -146,7 +215,7 @@ async def get_enabled_backends(): async for client in get_test_auth_client( [backend_none, backend_user], get_enabled_backends ): - response = await client.get("/test-current-user") + response = await client.get(path) assert response.status_code == status.HTTP_401_UNAUTHORIZED @@ -156,3 +225,42 @@ async def test_authenticators_with_same_name(get_test_auth_client, get_backend_n with pytest.raises(DuplicateBackendNamesError): async for _ in get_test_auth_client([get_backend_none(), get_backend_none()]): pass + + +@pytest.fixture(params=[status.HTTP_200_OK]) +def expected_http_status(request: pytest.FixtureRequest) -> int: + return getattr(request, "param") + + +@pytest.mark.authentication +@pytest.mark.asyncio +@pytest.mark.parametrize("require_active", [False], indirect=True) +@pytest.mark.parametrize("require_superuser", [False], indirect=True) +@pytest.mark.parametrize( + ("require_fresh", "token_fresh", "expected_http_status"), + [ + (False, False, status.HTTP_200_OK), + (False, True, status.HTTP_200_OK), + (True, False, status.HTTP_403_FORBIDDEN), + (True, True, status.HTTP_200_OK), + ], + indirect=True, +) +@pytest.mark.parametrize( + "path", + [ + "/test-current-user", + "/test-current-user-token", + "/test-current-token", + ], +) +async def test_freshness( + get_test_auth_client, + get_backend_none, + get_backend_user, + path: str, + expected_http_status: int, +): + async for client in get_test_auth_client([get_backend_none(), get_backend_user()]): + response = await client.get(path) + assert response.status_code == expected_http_status diff --git a/tests/test_authentication_strategy_db.py b/tests/test_authentication_strategy_db.py index 201f8949..b993e1f8 100644 --- a/tests/test_authentication_strategy_db.py +++ b/tests/test_authentication_strategy_db.py @@ -10,7 +10,8 @@ AccessTokenProtocol, DatabaseStrategy, ) -from tests.conftest import IDType, UserModel +from fastapi_users.authentication.token import UserTokenData +from tests.conftest import IDType, UserManager, UserModel @dataclasses.dataclass @@ -21,6 +22,11 @@ class AccessTokenModel(AccessTokenProtocol[IDType]): created_at: datetime = dataclasses.field( default_factory=lambda: datetime.now(timezone.utc) ) + expires_at: Optional[datetime] = None + last_authenticated: datetime = dataclasses.field( + default_factory=lambda: datetime.now(timezone.utc) + ) + scopes: str = "" class AccessTokenDatabaseMock(AccessTokenDatabase[AccessTokenModel]): @@ -60,6 +66,18 @@ async def delete(self, access_token: AccessTokenModel) -> None: pass +def assert_valid_token_model( + access_token: Optional[AccessTokenModel], + token_data: UserTokenData[UserModel, IDType], +): + assert access_token is not None + assert access_token.user_id == token_data.user.id + assert access_token.created_at == token_data.created_at + assert access_token.expires_at == token_data.expires_at + assert access_token.last_authenticated == token_data.last_authenticated + assert access_token.scopes == token_data.scope + + @pytest.fixture def access_token_database() -> AccessTokenDatabaseMock: return AccessTokenDatabaseMock() @@ -67,7 +85,7 @@ def access_token_database() -> AccessTokenDatabaseMock: @pytest.fixture def database_strategy(access_token_database: AccessTokenDatabaseMock): - return DatabaseStrategy(access_token_database, 3600) + return DatabaseStrategy(access_token_database) @pytest.mark.authentication @@ -75,8 +93,8 @@ class TestReadToken: @pytest.mark.asyncio async def test_missing_token( self, - database_strategy: DatabaseStrategy[UserModel, IDType, AccessTokenModel], - user_manager, + database_strategy: DatabaseStrategy[AccessTokenModel], + user_manager: UserManager, ): authenticated_user = await database_strategy.read_token(None, user_manager) assert authenticated_user is None @@ -84,8 +102,8 @@ async def test_missing_token( @pytest.mark.asyncio async def test_invalid_token( self, - database_strategy: DatabaseStrategy[UserModel, IDType, AccessTokenModel], - user_manager, + database_strategy: DatabaseStrategy[AccessTokenModel], + user_manager: UserManager, ): authenticated_user = await database_strategy.read_token("TOKEN", user_manager) assert authenticated_user is None @@ -93,51 +111,80 @@ async def test_invalid_token( @pytest.mark.asyncio async def test_valid_token_not_existing_user( self, - database_strategy: DatabaseStrategy[UserModel, IDType, AccessTokenModel], + database_strategy: DatabaseStrategy[AccessTokenModel], access_token_database: AccessTokenDatabaseMock, - user_manager, + user_manager: UserManager, + token_data: UserTokenData[UserModel, IDType], ): await access_token_database.create( { "token": "TOKEN", "user_id": uuid.UUID("d35d213e-f3d8-4f08-954a-7e0d1bea286f"), + "scopes": token_data.scope, + **token_data.dict(exclude={"user_id", "user", "scopes"}), } ) - authenticated_user = await database_strategy.read_token("TOKEN", user_manager) - assert authenticated_user is None + access_token = await database_strategy.read_token("TOKEN", user_manager) + assert access_token is None + + @pytest.mark.asyncio + @pytest.mark.parametrize("token_expired", [True]) + async def test_expired_token( + self, + database_strategy: DatabaseStrategy[AccessTokenModel], + access_token_database: AccessTokenDatabaseMock, + user_manager: UserManager, + token_data: UserTokenData[UserModel, IDType], + ): + await access_token_database.create( + { + "token": "TOKEN", + "user_id": token_data.user.id, + "scopes": token_data.scope, + **token_data.dict(exclude={"user_id", "user", "scopes"}), + } + ) + access_token = await database_strategy.read_token("TOKEN", user_manager) + assert access_token is None @pytest.mark.asyncio async def test_valid_token( self, - database_strategy: DatabaseStrategy[UserModel, IDType, AccessTokenModel], + database_strategy: DatabaseStrategy[AccessTokenModel], access_token_database: AccessTokenDatabaseMock, - user_manager, - user: UserModel, + user_manager: UserManager, + token_data: UserTokenData[UserModel, IDType], ): - await access_token_database.create({"token": "TOKEN", "user_id": user.id}) - authenticated_user = await database_strategy.read_token("TOKEN", user_manager) - assert authenticated_user is not None - assert authenticated_user.id == user.id + await access_token_database.create( + { + "token": "TOKEN", + "user_id": token_data.user.id, + "scopes": token_data.scope, + **token_data.dict(exclude={"user_id", "user", "scopes"}), + } + ) + access_token = await database_strategy.read_token("TOKEN", user_manager) + assert access_token is not None + assert access_token.dict() == token_data.dict() @pytest.mark.authentication @pytest.mark.asyncio async def test_write_token( - database_strategy: DatabaseStrategy[UserModel, IDType, AccessTokenModel], + database_strategy: DatabaseStrategy[AccessTokenModel], access_token_database: AccessTokenDatabaseMock, - user: UserModel, + token_data: UserTokenData[UserModel, IDType], ): - token = await database_strategy.write_token(user) + token = await database_strategy.write_token(token_data) access_token = await access_token_database.get_by_token(token) - assert access_token is not None - assert access_token.user_id == user.id + assert_valid_token_model(access_token, token_data) @pytest.mark.authentication @pytest.mark.asyncio async def test_destroy_token( - database_strategy: DatabaseStrategy[UserModel, IDType, AccessTokenModel], + database_strategy: DatabaseStrategy[AccessTokenModel], access_token_database: AccessTokenDatabaseMock, user: UserModel, ): diff --git a/tests/test_authentication_strategy_jwt.py b/tests/test_authentication_strategy_jwt.py index 45aea5a5..65e543d2 100644 --- a/tests/test_authentication_strategy_jwt.py +++ b/tests/test_authentication_strategy_jwt.py @@ -1,13 +1,18 @@ +from datetime import datetime +from typing import Any, Dict, Optional +from uuid import UUID + +import jwt import pytest +from pydantic import SecretStr from fastapi_users.authentication.strategy import ( JWTStrategy, StrategyDestroyNotSupportedError, ) +from fastapi_users.authentication.token import UserTokenData from fastapi_users.jwt import SecretType, decode_jwt, generate_jwt -from tests.conftest import IDType, UserModel - -LIFETIME = 3600 +from tests.conftest import IDType, UserManager, UserModel ECC_PRIVATE_KEY = """-----BEGIN PRIVATE KEY----- MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgewlS46hocOLtT9Px @@ -59,32 +64,65 @@ -----END PUBLIC KEY-----""" +@pytest.fixture(params=["SECRET"]) +def secret(request: pytest.FixtureRequest) -> SecretType: + return request.param # type: ignore + + +@pytest.fixture(params=["HS256"]) +def algorithm(request: pytest.FixtureRequest) -> str: + return request.param # type: ignore + + +@pytest.fixture(params=[None]) +def public_key(request: pytest.FixtureRequest) -> Optional[str]: + return request.param # type: ignore + + @pytest.fixture -def jwt_strategy(request, secret: SecretType): - if request.param == "HS256": - return JWTStrategy(secret, LIFETIME) - elif request.param == "RS256": - return JWTStrategy( - RSA_PRIVATE_KEY, LIFETIME, algorithm="RS256", public_key=RSA_PUBLIC_KEY - ) - elif request.param == "ES256": - return JWTStrategy( - ECC_PRIVATE_KEY, LIFETIME, algorithm="ES256", public_key=ECC_PUBLIC_KEY - ) - raise ValueError(f"Unrecognized algorithm: {request.param}") +@pytest.mark.parametrize( + ("algorithm", "secret", "public_key"), + [ + ("HS256", "SECRET", None), + ("HS256", SecretStr("SECRET"), None), + ("RS256", RSA_PRIVATE_KEY, RSA_PUBLIC_KEY), + ("ES256", ECC_PRIVATE_KEY, ECC_PUBLIC_KEY), + ], +) +def jwt_strategy(algorithm: str, secret: SecretType, public_key: Optional[str]): + if algorithm == "HS256": + return JWTStrategy(secret) # use default values + else: + return JWTStrategy(secret, algorithm=algorithm, public_key=public_key) @pytest.fixture -def token(jwt_strategy: JWTStrategy[UserModel, IDType]): - def _token(user_id=None, lifetime=LIFETIME): - data = {"aud": "fastapi-users:auth"} - if user_id is not None: - data["user_id"] = str(user_id) - return generate_jwt( - data, jwt_strategy.encode_key, lifetime, algorithm=jwt_strategy.algorithm - ) +def user_id(request: pytest.FixtureRequest, user: UserModel) -> Optional[UUID]: + if hasattr(request, "param"): + return request.param # type: ignore + else: + return user.id + - return _token +@pytest.fixture +def jwt_token( + jwt_strategy: JWTStrategy, + token_data: UserTokenData[UserModel, IDType], + user_id: Optional[UUID], +) -> str: + data: Dict[str, Any] = {"aud": "fastapi-users:auth"} + + if user_id is not None: + data["sub"] = str(user_id) + + if token_data.expires_at: + data["exp"] = int(token_data.expires_at.timestamp()) + + data["iat"] = int(token_data.created_at.timestamp()) + data["auth_time"] = int(token_data.last_authenticated.timestamp()) + data["scope"] = token_data.scope + + return generate_jwt(data, jwt_strategy.encode_key, algorithm=jwt_strategy.algorithm) @pytest.mark.parametrize("jwt_strategy", ["HS256", "RS256", "ES256"], indirect=True) @@ -92,68 +130,131 @@ def _token(user_id=None, lifetime=LIFETIME): class TestReadToken: @pytest.mark.asyncio async def test_missing_token( - self, jwt_strategy: JWTStrategy[UserModel, IDType], user_manager + self, + jwt_strategy: JWTStrategy, + user_manager: UserManager, ): authenticated_user = await jwt_strategy.read_token(None, user_manager) assert authenticated_user is None @pytest.mark.asyncio async def test_invalid_token( - self, jwt_strategy: JWTStrategy[UserModel, IDType], user_manager + self, + jwt_strategy: JWTStrategy, + user_manager: UserManager, ): authenticated_user = await jwt_strategy.read_token("foo", user_manager) assert authenticated_user is None @pytest.mark.asyncio - async def test_valid_token_missing_user_payload( - self, jwt_strategy: JWTStrategy[UserModel, IDType], user_manager, token - ): - authenticated_user = await jwt_strategy.read_token(token(), user_manager) - assert authenticated_user is None - - @pytest.mark.asyncio - async def test_valid_token_invalid_uuid( - self, jwt_strategy: JWTStrategy[UserModel, IDType], user_manager, token - ): - authenticated_user = await jwt_strategy.read_token(token("foo"), user_manager) - assert authenticated_user is None - - @pytest.mark.asyncio - async def test_valid_token_not_existing_user( - self, jwt_strategy: JWTStrategy[UserModel, IDType], user_manager, token + @pytest.mark.parametrize( + "user_id", + [ + None, + "foo", + "d35d213e-f3d8-4f08-954a-7e0d1bea286f", # non-existent user + ], + indirect=True, + ) + async def test_valid_token_invalid_user( + self, + jwt_strategy: JWTStrategy, + user_manager: UserManager, + jwt_token: str, ): - authenticated_user = await jwt_strategy.read_token( - token("d35d213e-f3d8-4f08-954a-7e0d1bea286f"), user_manager - ) + authenticated_user = await jwt_strategy.read_token(jwt_token, user_manager) assert authenticated_user is None - @pytest.mark.asyncio - async def test_valid_token( - self, jwt_strategy: JWTStrategy[UserModel, IDType], user_manager, token, user - ): - authenticated_user = await jwt_strategy.read_token(token(user.id), user_manager) - assert authenticated_user is not None - assert authenticated_user.id == user.id - @pytest.mark.parametrize("jwt_strategy", ["HS256", "RS256", "ES256"], indirect=True) +@pytest.mark.parametrize("token_expired", [False], indirect=True) @pytest.mark.authentication @pytest.mark.asyncio -async def test_write_token(jwt_strategy: JWTStrategy[UserModel, IDType], user): - token = await jwt_strategy.write_token(user) - +async def test_write_token( + jwt_strategy: JWTStrategy, token_data: UserTokenData[UserModel, IDType] +): + token = await jwt_strategy.write_token(token_data) decoded = decode_jwt( token, jwt_strategy.decode_key, audience=jwt_strategy.token_audience, algorithms=[jwt_strategy.algorithm], ) - assert decoded["user_id"] == str(user.id) + assert decoded["sub"] == str(token_data.user.id) + assert decoded["iat"] == int(token_data.created_at.timestamp()) + assert decoded["scope"] == token_data.scope + assert decoded["auth_time"] == int(token_data.last_authenticated.timestamp()) + if token_data.expires_at: + assert "exp" in decoded + assert decoded["exp"] == int(token_data.expires_at.timestamp()) + + +@pytest.mark.parametrize("jwt_strategy", ["HS256", "RS256", "ES256"], indirect=True) +@pytest.mark.parametrize("token_expired", [True], indirect=True) +@pytest.mark.authentication +@pytest.mark.asyncio +async def test_write_token_expired( + jwt_strategy: JWTStrategy, token_data: UserTokenData[UserModel, IDType] +): + token = await jwt_strategy.write_token(token_data) + + with pytest.raises(jwt.exceptions.ExpiredSignatureError): + decode_jwt( + token, + jwt_strategy.decode_key, + audience=jwt_strategy.token_audience, + algorithms=[jwt_strategy.algorithm], + ) + + +def assert_token_data_approximately_equal( + left: UserTokenData[UserModel, IDType], right: UserTokenData[UserModel, IDType] +): + def assert_seconds_equal(left: datetime, right: datetime): + assert left.replace(microsecond=0) == right.replace(microsecond=0) + + assert left.user == right.user + assert left.scopes == right.scopes + if left.expires_at and right.expires_at: + assert_seconds_equal(left.expires_at, right.expires_at) + else: + assert left.expires_at == right.expires_at + assert_seconds_equal(left.created_at, right.created_at) + assert_seconds_equal(left.last_authenticated, right.last_authenticated) + + +@pytest.mark.parametrize("jwt_strategy", ["HS256", "RS256", "ES256"], indirect=True) +@pytest.mark.parametrize("token_expired", [False], indirect=True) +@pytest.mark.authentication +@pytest.mark.asyncio +async def test_read_token( + jwt_strategy: JWTStrategy, + token_data: UserTokenData[UserModel, IDType], + user_manager: UserManager, +): + token = await jwt_strategy.write_token(token_data) + decoded = await jwt_strategy.read_token(token, user_manager) + assert decoded is not None + assert_token_data_approximately_equal(decoded, token_data) + + +@pytest.mark.parametrize("jwt_strategy", ["HS256", "RS256", "ES256"], indirect=True) +@pytest.mark.parametrize("token_expired", [True], indirect=True) +@pytest.mark.authentication +@pytest.mark.asyncio +async def test_read_token_expired( + jwt_strategy: JWTStrategy, + token_data: UserTokenData[UserModel, IDType], + user_manager: UserManager, +): + token = await jwt_strategy.write_token(token_data) + decoded = await jwt_strategy.read_token(token, user_manager) + assert decoded == None @pytest.mark.parametrize("jwt_strategy", ["HS256", "RS256", "ES256"], indirect=True) @pytest.mark.authentication @pytest.mark.asyncio -async def test_destroy_token(jwt_strategy: JWTStrategy[UserModel, IDType], user): +async def test_destroy_token(jwt_strategy: JWTStrategy, user: UserModel): with pytest.raises(StrategyDestroyNotSupportedError): await jwt_strategy.destroy_token("TOKEN", user) diff --git a/tests/test_authentication_strategy_redis.py b/tests/test_authentication_strategy_redis.py index 7d178906..89babbe6 100644 --- a/tests/test_authentication_strategy_redis.py +++ b/tests/test_authentication_strategy_redis.py @@ -1,10 +1,12 @@ +import json from datetime import datetime from typing import Dict, Optional, Tuple import pytest from fastapi_users.authentication.strategy import RedisStrategy -from tests.conftest import IDType, UserModel +from fastapi_users.authentication.token import UserTokenData +from tests.conftest import IDType, UserManager, UserModel class RedisMock: @@ -42,29 +44,25 @@ def redis() -> RedisMock: @pytest.fixture def redis_strategy(redis): - return RedisStrategy(redis, 3600) + return RedisStrategy(redis) @pytest.mark.authentication class TestReadToken: @pytest.mark.asyncio - async def test_missing_token( - self, redis_strategy: RedisStrategy[UserModel, IDType], user_manager - ): + async def test_missing_token(self, redis_strategy: RedisStrategy, user_manager): authenticated_user = await redis_strategy.read_token(None, user_manager) assert authenticated_user is None @pytest.mark.asyncio - async def test_invalid_token( - self, redis_strategy: RedisStrategy[UserModel, IDType], user_manager - ): + async def test_invalid_token(self, redis_strategy: RedisStrategy, user_manager): authenticated_user = await redis_strategy.read_token("TOKEN", user_manager) assert authenticated_user is None @pytest.mark.asyncio async def test_valid_token_invalid_uuid( self, - redis_strategy: RedisStrategy[UserModel, IDType], + redis_strategy: RedisStrategy, redis: RedisMock, user_manager, ): @@ -75,7 +73,7 @@ async def test_valid_token_invalid_uuid( @pytest.mark.asyncio async def test_valid_token_not_existing_user( self, - redis_strategy: RedisStrategy[UserModel, IDType], + redis_strategy: RedisStrategy, redis: RedisMock, user_manager, ): @@ -85,35 +83,69 @@ async def test_valid_token_not_existing_user( authenticated_user = await redis_strategy.read_token("TOKEN", user_manager) assert authenticated_user is None - @pytest.mark.asyncio - async def test_valid_token( - self, - redis_strategy: RedisStrategy[UserModel, IDType], - redis: RedisMock, - user_manager, - user, - ): - await redis.set(f"{redis_strategy.key_prefix}TOKEN", str(user.id)) - authenticated_user = await redis_strategy.read_token("TOKEN", user_manager) - assert authenticated_user is not None - assert authenticated_user.id == user.id - +@pytest.mark.parametrize("token_expired", [False], indirect=True) @pytest.mark.authentication @pytest.mark.asyncio async def test_write_token( - redis_strategy: RedisStrategy[UserModel, IDType], redis: RedisMock, user + redis_strategy: RedisStrategy, + redis: RedisMock, + token_data: UserTokenData[UserModel, IDType], ): - token = await redis_strategy.write_token(user) + token = await redis_strategy.write_token(token_data) + + token_value = await redis.get(f"{redis_strategy.key_prefix}{token}") + assert token_value is not None + + decoded = json.loads(token_value) + + assert decoded["user_id"] == str(token_data.user.id) + assert set(decoded["scopes"]) == token_data.scopes + + assert datetime.fromisoformat(decoded["created_at"]) == token_data.created_at + + assert ( + datetime.fromisoformat(decoded["last_authenticated"]) + == token_data.last_authenticated + ) + + if token_data.expires_at: + assert "expires_at" in decoded + assert datetime.fromisoformat(decoded["expires_at"]) == token_data.expires_at + + +@pytest.mark.parametrize("token_expired", [True], indirect=True) +@pytest.mark.authentication +@pytest.mark.asyncio +async def test_write_token_expired( + redis_strategy: RedisStrategy, + redis: RedisMock, + token_data: UserTokenData[UserModel, IDType], +): + token = await redis_strategy.write_token(token_data) value = await redis.get(f"{redis_strategy.key_prefix}{token}") - assert value == str(user.id) + assert value is None + + +@pytest.mark.parametrize("token_expired", [False], indirect=True) +@pytest.mark.authentication +@pytest.mark.asyncio +async def test_read_token( + redis_strategy: RedisStrategy, + redis: RedisMock, + token_data: UserTokenData[UserModel, IDType], + user_manager: UserManager, +): + token = await redis_strategy.write_token(token_data) + retrieved_token_data = await redis_strategy.read_token(token, user_manager) + assert retrieved_token_data == token_data @pytest.mark.authentication @pytest.mark.asyncio async def test_destroy_token( - redis_strategy: RedisStrategy[UserModel, IDType], redis: RedisMock, user + redis_strategy: RedisStrategy, redis: RedisMock, user: UserModel ): await redis.set(f"{redis_strategy.key_prefix}TOKEN", str(user.id)) diff --git a/tests/test_authentication_token.py b/tests/test_authentication_token.py new file mode 100644 index 00000000..87cb1baa --- /dev/null +++ b/tests/test_authentication_token.py @@ -0,0 +1,52 @@ +import json +from datetime import datetime + +import pytest + +from fastapi_users.authentication.token import TokenData, UserTokenData +from tests.conftest import IDType, UserManager, UserModel + + +@pytest.mark.authentication +def test_token_to_json(token_data: UserTokenData[UserModel, IDType]): + + token = token_data.json() + + decoded = json.loads(token) + + assert decoded["user_id"] == str(token_data.user.id) + assert set(decoded["scopes"]) == token_data.scopes + + assert datetime.fromisoformat(decoded["created_at"]) == token_data.created_at + + assert ( + datetime.fromisoformat(decoded["last_authenticated"]) + == token_data.last_authenticated + ) + + if token_data.expires_at: + assert "expires_at" in decoded + assert datetime.fromisoformat(decoded["expires_at"]) == token_data.expires_at + + +@pytest.mark.authentication +@pytest.mark.asyncio +async def test_token_from_json( + token_data: UserTokenData[UserModel, IDType], user_manager: UserManager +): + + token_data_dict = { + "user_id": str(token_data.user.id), + "created_at": token_data.created_at.isoformat(), + "last_authenticated": token_data.last_authenticated.isoformat(), + "scopes": list(token_data.scopes), + } + if token_data.expires_at: + token_data_dict["expires_at"] = token_data.expires_at.isoformat() + + token = json.dumps(token_data_dict) + + parsed_token = TokenData.parse_raw(token) + parsed_user_token = await parsed_token.lookup_user(user_manager) + + assert parsed_user_token == token_data diff --git a/tests/test_authentication_transport_bearer.py b/tests/test_authentication_transport_bearer.py index cadf0457..e0b494a0 100644 --- a/tests/test_authentication_transport_bearer.py +++ b/tests/test_authentication_transport_bearer.py @@ -4,6 +4,7 @@ from fastapi_users.authentication.transport import ( BearerTransport, TransportLogoutNotSupportedError, + TransportTokenResponse, ) from fastapi_users.authentication.transport.bearer import BearerResponse @@ -17,11 +18,29 @@ def bearer_transport() -> BearerTransport: @pytest.mark.asyncio async def test_get_login_response(bearer_transport: BearerTransport): response = Response() - login_response = await bearer_transport.get_login_response("TOKEN", response) + token_response = TransportTokenResponse(access_token="TOKEN") + login_response = await bearer_transport.get_login_response(token_response, response) assert isinstance(login_response, BearerResponse) assert login_response.access_token == "TOKEN" + assert login_response.refresh_token == None + assert login_response.token_type == "bearer" + + +@pytest.mark.authentication +@pytest.mark.asyncio +async def test_get_login_response_with_refresh(bearer_transport: BearerTransport): + response = Response() + token_response = TransportTokenResponse( + access_token="TOKEN", refresh_token="REFRESH_TOKEN" + ) + login_response = await bearer_transport.get_login_response(token_response, response) + + assert isinstance(login_response, BearerResponse) + + assert login_response.access_token == "TOKEN" + assert login_response.refresh_token == "REFRESH_TOKEN" assert login_response.token_type == "bearer" diff --git a/tests/test_authentication_transport_cookie.py b/tests/test_authentication_transport_cookie.py index ca99f9ca..c5557966 100644 --- a/tests/test_authentication_transport_cookie.py +++ b/tests/test_authentication_transport_cookie.py @@ -3,7 +3,10 @@ import pytest from fastapi import Response, status -from fastapi_users.authentication.transport import CookieTransport +from fastapi_users.authentication.transport import ( + CookieTransport, + TransportTokenResponse, +) COOKIE_MAX_AGE = 3600 COOKIE_NAME = "COOKIE_NAME" @@ -39,7 +42,8 @@ async def test_get_login_response(cookie_transport: CookieTransport): httponly = cookie_transport.cookie_httponly response = Response() - login_response = await cookie_transport.get_login_response("TOKEN", response) + token_response = TransportTokenResponse(access_token="TOKEN") + login_response = await cookie_transport.get_login_response(token_response, response) assert login_response is None diff --git a/tests/test_fastapi_users.py b/tests/test_fastapi_users.py index 29109931..b378f3fa 100644 --- a/tests/test_fastapi_users.py +++ b/tests/test_fastapi_users.py @@ -5,7 +5,14 @@ from fastapi import Depends, FastAPI, status from fastapi_users import FastAPIUsers -from tests.conftest import IDType, User, UserCreate, UserModel, UserUpdate +from tests.conftest import ( + IDType, + User, + UserCreate, + UserModel, + UserUpdate, + mock_authorized_headers, +) @pytest.fixture @@ -167,7 +174,7 @@ async def test_valid_token( self, test_app_client: httpx.AsyncClient, user: UserModel ): response = await test_app_client.get( - "/current-user", headers={"Authorization": f"Bearer {user.id}"} + "/current-user", headers=mock_authorized_headers(user) ) assert response.status_code == status.HTTP_200_OK @@ -190,7 +197,7 @@ async def test_valid_token_inactive_user( ): response = await test_app_client.get( "/current-active-user", - headers={"Authorization": f"Bearer {inactive_user.id}"}, + headers=mock_authorized_headers(inactive_user), ) assert response.status_code == status.HTTP_401_UNAUTHORIZED @@ -198,7 +205,7 @@ async def test_valid_token( self, test_app_client: httpx.AsyncClient, user: UserModel ): response = await test_app_client.get( - "/current-active-user", headers={"Authorization": f"Bearer {user.id}"} + "/current-active-user", headers=mock_authorized_headers(user) ) assert response.status_code == status.HTTP_200_OK @@ -221,7 +228,7 @@ async def test_valid_token_unverified_user( ): response = await test_app_client.get( "/current-verified-user", - headers={"Authorization": f"Bearer {user.id}"}, + headers=mock_authorized_headers(user), ) assert response.status_code == status.HTTP_403_FORBIDDEN @@ -230,7 +237,7 @@ async def test_valid_token_verified_user( ): response = await test_app_client.get( "/current-verified-user", - headers={"Authorization": f"Bearer {verified_user.id}"}, + headers=mock_authorized_headers(verified_user), ) assert response.status_code == status.HTTP_200_OK @@ -252,7 +259,7 @@ async def test_valid_token_regular_user( self, test_app_client: httpx.AsyncClient, user: UserModel ): response = await test_app_client.get( - "/current-superuser", headers={"Authorization": f"Bearer {user.id}"} + "/current-superuser", headers=mock_authorized_headers(user) ) assert response.status_code == status.HTTP_403_FORBIDDEN @@ -260,7 +267,7 @@ async def test_valid_token_superuser( self, test_app_client: httpx.AsyncClient, superuser: UserModel ): response = await test_app_client.get( - "/current-superuser", headers={"Authorization": f"Bearer {superuser.id}"} + "/current-superuser", headers=mock_authorized_headers(superuser) ) assert response.status_code == status.HTTP_200_OK @@ -283,7 +290,7 @@ async def test_valid_token_regular_user( ): response = await test_app_client.get( "/current-verified-superuser", - headers={"Authorization": f"Bearer {user.id}"}, + headers=mock_authorized_headers(user), ) assert response.status_code == status.HTTP_403_FORBIDDEN @@ -292,7 +299,7 @@ async def test_valid_token_verified_user( ): response = await test_app_client.get( "/current-verified-superuser", - headers={"Authorization": f"Bearer {verified_user.id}"}, + headers=mock_authorized_headers(verified_user), ) assert response.status_code == status.HTTP_403_FORBIDDEN @@ -301,7 +308,7 @@ async def test_valid_token_superuser( ): response = await test_app_client.get( "/current-verified-superuser", - headers={"Authorization": f"Bearer {superuser.id}"}, + headers=mock_authorized_headers(superuser), ) assert response.status_code == status.HTTP_403_FORBIDDEN @@ -310,7 +317,7 @@ async def test_valid_token_verified_superuser( ): response = await test_app_client.get( "/current-verified-superuser", - headers={"Authorization": f"Bearer {verified_superuser.id}"}, + headers=mock_authorized_headers(verified_superuser), ) assert response.status_code == status.HTTP_200_OK @@ -334,7 +341,7 @@ async def test_valid_token( self, test_app_client: httpx.AsyncClient, user: UserModel ): response = await test_app_client.get( - "/optional-current-user", headers={"Authorization": f"Bearer {user.id}"} + "/optional-current-user", headers=mock_authorized_headers(user) ) assert response.status_code == status.HTTP_200_OK assert response.json() is not None @@ -360,7 +367,7 @@ async def test_valid_token_unverified_user( ): response = await test_app_client.get( "/optional-current-verified-user", - headers={"Authorization": f"Bearer {user.id}"}, + headers=mock_authorized_headers(user), ) assert response.status_code == status.HTTP_200_OK assert response.json() is None @@ -370,7 +377,7 @@ async def test_valid_token_verified_user( ): response = await test_app_client.get( "/optional-current-verified-user", - headers={"Authorization": f"Bearer {verified_user.id}"}, + headers=mock_authorized_headers(verified_user), ) assert response.status_code == status.HTTP_200_OK assert response.json() is not None @@ -396,7 +403,7 @@ async def test_valid_token_inactive_user( ): response = await test_app_client.get( "/optional-current-active-user", - headers={"Authorization": f"Bearer {inactive_user.id}"}, + headers=mock_authorized_headers(inactive_user), ) assert response.status_code == status.HTTP_200_OK assert response.json() is None @@ -406,7 +413,7 @@ async def test_valid_token( ): response = await test_app_client.get( "/optional-current-active-user", - headers={"Authorization": f"Bearer {user.id}"}, + headers=mock_authorized_headers(user), ) assert response.status_code == status.HTTP_200_OK assert response.json() is not None @@ -432,7 +439,7 @@ async def test_valid_token_regular_user( ): response = await test_app_client.get( "/optional-current-superuser", - headers={"Authorization": f"Bearer {user.id}"}, + headers=mock_authorized_headers(user), ) assert response.status_code == status.HTTP_200_OK assert response.json() is None @@ -442,7 +449,7 @@ async def test_valid_token_superuser( ): response = await test_app_client.get( "/optional-current-superuser", - headers={"Authorization": f"Bearer {superuser.id}"}, + headers=mock_authorized_headers(superuser), ) assert response.status_code == status.HTTP_200_OK assert response.json() is not None @@ -469,7 +476,7 @@ async def test_valid_token_regular_user( ): response = await test_app_client.get( "/optional-current-verified-superuser", - headers={"Authorization": f"Bearer {user.id}"}, + headers=mock_authorized_headers(user), ) assert response.status_code == status.HTTP_200_OK assert response.json() is None @@ -479,7 +486,7 @@ async def test_valid_token_verified_user( ): response = await test_app_client.get( "/optional-current-verified-superuser", - headers={"Authorization": f"Bearer {verified_user.id}"}, + headers=mock_authorized_headers(verified_user), ) assert response.status_code == status.HTTP_200_OK assert response.json() is None @@ -489,7 +496,7 @@ async def test_valid_token_superuser( ): response = await test_app_client.get( "/optional-current-verified-superuser", - headers={"Authorization": f"Bearer {superuser.id}"}, + headers=mock_authorized_headers(superuser), ) assert response.status_code == status.HTTP_200_OK assert response.json() is None @@ -499,7 +506,7 @@ async def test_valid_token_verified_superuser( ): response = await test_app_client.get( "/optional-current-verified-superuser", - headers={"Authorization": f"Bearer {verified_superuser.id}"}, + headers=mock_authorized_headers(verified_superuser), ) assert response.status_code == status.HTTP_200_OK assert response.json() is not None diff --git a/tests/test_manager.py b/tests/test_manager.py index bfafad0a..eb0f0169 100644 --- a/tests/test_manager.py +++ b/tests/test_manager.py @@ -17,6 +17,7 @@ ) from fastapi_users.jwt import decode_jwt, generate_jwt from fastapi_users.manager import IntegerIDMixin +from fastapi_users.scopes import SystemScope from tests.conftest import ( UserCreate, UserManagerMock, diff --git a/tests/test_router_auth.py b/tests/test_router_auth.py index 1bbc8026..ec98f481 100644 --- a/tests/test_router_auth.py +++ b/tests/test_router_auth.py @@ -1,4 +1,4 @@ -from typing import Any, AsyncGenerator, Dict, Tuple, cast +from typing import Any, AsyncGenerator, Callable, Dict, Optional, Tuple, cast import httpx import pytest @@ -6,13 +6,23 @@ from fastapi_users.authentication import Authenticator from fastapi_users.router import ErrorCode, get_auth_router -from tests.conftest import UserModel, get_mock_authentication +from fastapi_users.scopes import SystemScope +from tests.conftest import ( + MockBackend, + UserModel, + assert_valid_token_response, + mock_authorized_headers, + mock_token_data, +) @pytest.fixture -def app_factory(get_user_manager, mock_authentication): +def app_factory( + mock_authentication_factory: Callable[[str], MockBackend], get_user_manager +): def _app_factory(requires_verification: bool) -> FastAPI: - mock_authentication_bis = get_mock_authentication(name="mock-bis") + mock_authentication = mock_authentication_factory("mock") + mock_authentication_bis = mock_authentication_factory("mock-bis") authenticator = Authenticator( [mock_authentication, mock_authentication_bis], get_user_manager ) @@ -113,12 +123,17 @@ async def test_wrong_password( @pytest.mark.parametrize( "email", ["king.arthur@camelot.bt", "King.Arthur@camelot.bt"] ) + @pytest.mark.parametrize( + "access_token_lifetime_seconds", [None, 3600], indirect=True + ) + @pytest.mark.freeze_time async def test_valid_credentials_unverified( self, path, email, test_app_client: Tuple[httpx.AsyncClient, bool], user: UserModel, + access_token_lifetime_seconds: Optional[int], ): client, requires_verification = test_app_client data = {"username": email, "password": "guinevere"} @@ -129,27 +144,36 @@ async def test_valid_credentials_unverified( assert data["detail"] == ErrorCode.LOGIN_USER_NOT_VERIFIED else: assert response.status_code == status.HTTP_200_OK - assert response.json() == { - "access_token": str(user.id), - "token_type": "bearer", - } + expected_access_token = mock_token_data( + user_id=user.id, + scopes={SystemScope.USER}, + lifetime_seconds=access_token_lifetime_seconds, + ) + assert_valid_token_response(response.json(), expected_access_token) @pytest.mark.parametrize("email", ["lake.lady@camelot.bt", "Lake.Lady@camelot.bt"]) + @pytest.mark.parametrize( + "access_token_lifetime_seconds", [None, 3600], indirect=True + ) + @pytest.mark.freeze_time async def test_valid_credentials_verified( self, path, email, test_app_client: Tuple[httpx.AsyncClient, bool], verified_user: UserModel, + access_token_lifetime_seconds: Optional[int], ): client, _ = test_app_client data = {"username": email, "password": "excalibur"} response = await client.post(path, data=data) assert response.status_code == status.HTTP_200_OK - assert response.json() == { - "access_token": str(verified_user.id), - "token_type": "bearer", - } + expected_access_token = mock_token_data( + user_id=verified_user.id, + scopes={SystemScope.USER, SystemScope.VERIFIED}, + lifetime_seconds=access_token_lifetime_seconds, + ) + assert_valid_token_response(response.json(), expected_access_token) async def test_inactive_user( self, @@ -185,9 +209,7 @@ async def test_valid_credentials_unverified( user: UserModel, ): client, requires_verification = test_app_client - response = await client.post( - path, headers={"Authorization": f"Bearer {user.id}"} - ) + response = await client.post(path, headers=mock_authorized_headers(user)) if requires_verification: assert response.status_code == status.HTTP_403_FORBIDDEN else: @@ -202,7 +224,7 @@ async def test_valid_credentials_verified( ): client, _ = test_app_client response = await client.post( - path, headers={"Authorization": f"Bearer {verified_user.id}"} + path, headers=mock_authorized_headers(verified_user) ) assert response.status_code == status.HTTP_200_OK diff --git a/tests/test_router_oauth.py b/tests/test_router_oauth.py index c13c164d..c71df67e 100644 --- a/tests/test_router_oauth.py +++ b/tests/test_router_oauth.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, cast +from typing import Any, Dict, Optional, cast import httpx import pytest @@ -13,12 +13,16 @@ get_oauth_associate_router, get_oauth_router, ) +from fastapi_users.scopes import SystemScope from tests.conftest import ( AsyncMethodMocker, User, UserManagerMock, UserModel, UserOAuthModel, + assert_valid_token_response, + mock_authorized_headers, + mock_token_data, ) @@ -188,6 +192,10 @@ async def test_already_exists_error( data = cast(Dict[str, Any], response.json()) assert data["detail"] == ErrorCode.OAUTH_USER_ALREADY_EXISTS + @pytest.mark.parametrize( + "access_token_lifetime_seconds", [None, 3600], indirect=True + ) + @pytest.mark.freeze_time async def test_active_user( self, async_method_mocker: AsyncMethodMocker, @@ -196,6 +204,7 @@ async def test_active_user( user_oauth: UserOAuthModel, user_manager_oauth: UserManagerMock, access_token: str, + access_token_lifetime_seconds: Optional[int], ): state_jwt = generate_state_token({}, "SECRET") async_method_mocker(oauth_client, "get_access_token", return_value=access_token) @@ -213,8 +222,12 @@ async def test_active_user( assert response.status_code == status.HTTP_200_OK - data = cast(Dict[str, Any], response.json()) - assert data["access_token"] == str(user_oauth.id) + expected_access_token = mock_token_data( + user_id=user_oauth.id, + scopes={SystemScope.USER}, + lifetime_seconds=access_token_lifetime_seconds, + ) + assert_valid_token_response(response.json(), expected_access_token) async def test_inactive_user( self, @@ -243,6 +256,10 @@ async def test_inactive_user( assert response.status_code == status.HTTP_400_BAD_REQUEST + @pytest.mark.parametrize( + "access_token_lifetime_seconds", [None, 3600], indirect=True + ) + @pytest.mark.freeze_time async def test_redirect_url_router( self, async_method_mocker: AsyncMethodMocker, @@ -251,6 +268,7 @@ async def test_redirect_url_router( user_oauth: UserOAuthModel, user_manager_oauth: UserManagerMock, access_token: str, + access_token_lifetime_seconds: Optional[int], ): state_jwt = generate_state_token({}, "SECRET") get_access_token_mock = async_method_mocker( @@ -273,9 +291,12 @@ async def test_redirect_url_router( get_access_token_mock.assert_called_once_with( "CODE", "http://www.tintagel.bt/callback", None ) - - data = cast(Dict[str, Any], response.json()) - assert data["access_token"] == str(user_oauth.id) + expected_access_token = mock_token_data( + user_id=user_oauth.id, + scopes={SystemScope.USER}, + lifetime_seconds=access_token_lifetime_seconds, + ) + assert_valid_token_response(response.json(), expected_access_token) @pytest.mark.router @@ -295,7 +316,7 @@ async def test_inactive_user( response = await test_app_client.get( "/oauth-associate/authorize", params={"scopes": ["scope1", "scope2"]}, - headers={"Authorization": f"Bearer {inactive_user_oauth.id}"}, + headers=mock_authorized_headers(inactive_user_oauth), ) assert response.status_code == status.HTTP_401_UNAUTHORIZED @@ -314,7 +335,7 @@ async def test_active_user( response = await test_app_client.get( "/oauth-associate/authorize", params={"scopes": ["scope1", "scope2"]}, - headers={"Authorization": f"Bearer {user_oauth.id}"}, + headers=mock_authorized_headers(user_oauth), ) assert response.status_code == status.HTTP_200_OK @@ -337,7 +358,7 @@ async def test_with_redirect_url( response = await test_app_client_redirect_url.get( "/oauth-associate/authorize", params={"scopes": ["scope1", "scope2"]}, - headers={"Authorization": f"Bearer {user_oauth.id}"}, + headers=mock_authorized_headers(user_oauth), ) assert response.status_code == status.HTTP_200_OK @@ -377,7 +398,7 @@ async def test_inactive_user( response = await test_app_client.get( "/oauth-associate/callback", params={"code": "CODE", "state": "STATE"}, - headers={"Authorization": f"Bearer {inactive_user_oauth.id}"}, + headers=mock_authorized_headers(inactive_user_oauth), ) assert response.status_code == status.HTTP_401_UNAUTHORIZED @@ -398,7 +419,7 @@ async def test_invalid_state( response = await test_app_client.get( "/oauth-associate/callback", params={"code": "CODE", "state": "STATE"}, - headers={"Authorization": f"Bearer {user_oauth.id}"}, + headers=mock_authorized_headers(user_oauth), ) assert response.status_code == status.HTTP_400_BAD_REQUEST @@ -422,7 +443,7 @@ async def test_state_with_different_user_id( response = await test_app_client.get( "/oauth-associate/callback", params={"code": "CODE", "state": state_jwt}, - headers={"Authorization": f"Bearer {user_oauth.id}"}, + headers=mock_authorized_headers(user_oauth), ) assert response.status_code == status.HTTP_400_BAD_REQUEST @@ -449,7 +470,7 @@ async def test_active_user( response = await test_app_client.get( "/oauth-associate/callback", params={"code": "CODE", "state": state_jwt}, - headers={"Authorization": f"Bearer {user_oauth.id}"}, + headers=mock_authorized_headers(user_oauth), ) assert response.status_code == status.HTTP_200_OK @@ -480,7 +501,7 @@ async def test_redirect_url_router( response = await test_app_client_redirect_url.get( "/oauth-associate/callback", params={"code": "CODE", "state": state_jwt}, - headers={"Authorization": f"Bearer {user_oauth.id}"}, + headers=mock_authorized_headers(user_oauth), ) assert response.status_code == status.HTTP_200_OK diff --git a/tests/test_router_refresh.py b/tests/test_router_refresh.py new file mode 100644 index 00000000..b274a408 --- /dev/null +++ b/tests/test_router_refresh.py @@ -0,0 +1,173 @@ +from datetime import datetime +from typing import AsyncGenerator, Callable, Optional, Tuple + +import httpx +import pytest +from fastapi import FastAPI, status + +from fastapi_users.authentication import Authenticator +from fastapi_users.router import get_auth_router, get_refresh_router +from fastapi_users.scopes import SystemScope +from tests.conftest import ( + MockBackend, + UserModel, + assert_valid_token_response, + mock_token_data, + mock_valid_access_token, + mock_valid_refresh_token, +) + + +@pytest.fixture +def app_factory( + mock_authentication_factory: Callable[[str], MockBackend], get_user_manager +): + def _app_factory(requires_verification: bool) -> FastAPI: + mock_authentication = mock_authentication_factory("mock") + mock_authentication_bis = mock_authentication_factory("mock-bis") + authenticator = Authenticator( + [mock_authentication, mock_authentication_bis], get_user_manager + ) + + mock_auth_router = get_auth_router( + mock_authentication, + get_user_manager, + authenticator, + requires_verification=requires_verification, + ) + mock_bis_auth_router = get_auth_router( + mock_authentication_bis, + get_user_manager, + authenticator, + requires_verification=requires_verification, + ) + + mock_refresh_router = get_refresh_router( + mock_authentication, + get_user_manager, + ) + mock_bis_refresh_router = get_refresh_router( + mock_authentication_bis, + get_user_manager, + ) + + app = FastAPI() + app.include_router(mock_auth_router, prefix="/mock") + app.include_router(mock_bis_auth_router, prefix="/mock-bis") + app.include_router(mock_refresh_router, prefix="/mock") + app.include_router(mock_bis_refresh_router, prefix="/mock-bis") + + return app + + return _app_factory + + +@pytest.fixture( + params=[True, False], ids=["required_verification", "not_required_verification"] +) +@pytest.mark.asyncio +async def test_app_client( + request, get_test_client, app_factory +) -> AsyncGenerator[Tuple[httpx.AsyncClient, bool], None]: + requires_verification = request.param + app = app_factory(requires_verification) + + async for client in get_test_client(app): + yield client, requires_verification + + +@pytest.mark.router +@pytest.mark.parametrize("refresh_token_enabled", [True], indirect=True) +@pytest.mark.parametrize("path", ["/mock", "/mock-bis"]) +@pytest.mark.asyncio +class TestRefresh: + async def test_malformed_token( + self, + path, + test_app_client: Tuple[httpx.AsyncClient, bool], + ): + client, _ = test_app_client + data = {"grant_type": "refresh_token", "refresh_token": "foo"} + response = await client.post(f"{path}/refresh", data=data) + assert response.status_code == status.HTTP_400_BAD_REQUEST + + async def test_access_token_used_as_refresh_token( + self, + path, + test_app_client: Tuple[httpx.AsyncClient, bool], + verified_user: UserModel, + ): + client, _ = test_app_client + + data = { + "grant_type": "refresh_token", + "refresh_token": mock_valid_access_token(verified_user), + } + response = await client.post(f"{path}/refresh", data=data) + assert response.status_code == status.HTTP_400_BAD_REQUEST + + @pytest.mark.parametrize( + "access_token_lifetime_seconds", [None, 3600], indirect=True + ) + @pytest.mark.parametrize( + "refresh_token_lifetime_seconds", [None, 86400], indirect=True + ) + @pytest.mark.freeze_time("2022-09-01 09:00") + async def test_valid_refresh_token( + self, + path, + test_app_client: Tuple[httpx.AsyncClient, bool], + verified_user: UserModel, + freezer, + access_token_lifetime_seconds: Optional[int], + refresh_token_lifetime_seconds: Optional[int], + ): + print(f"test_valid_refresh_token: {access_token_lifetime_seconds}") + client, _ = test_app_client + + data = { + "grant_type": "refresh_token", + "refresh_token": mock_valid_refresh_token(verified_user), + } + freezer.move_to("2022-09-01 10:00") + response = await client.post(f"{path}/refresh", data=data) + assert response.status_code == status.HTTP_200_OK + expected_access_token = mock_token_data( + user_id=verified_user.id, + scopes={SystemScope.USER, SystemScope.VERIFIED}, + lifetime_seconds=access_token_lifetime_seconds, + last_authenticated=datetime.fromisoformat("2022-09-01 09:00+00:00"), + ) + expected_refresh_token = mock_token_data( + user_id=verified_user.id, + scopes={SystemScope.REFRESH}, + lifetime_seconds=refresh_token_lifetime_seconds, + last_authenticated=datetime.fromisoformat("2022-09-01 09:00+00:00"), + ) + assert_valid_token_response( + response.json(), + expected_access_token, + expected_refresh_token, + ) + + +@pytest.mark.router +@pytest.mark.parametrize("refresh_token_enabled", [False], indirect=True) +@pytest.mark.parametrize("path", ["/mock", "/mock-bis"]) +@pytest.mark.asyncio +class TestMisconfiguredRefresh: + async def test_valid_refresh_token( + self, + path, + test_app_client: Tuple[httpx.AsyncClient, bool], + verified_user: UserModel, + ): + client, _ = test_app_client + + data = { + "grant_type": "refresh_token", + "refresh_token": mock_valid_refresh_token(verified_user), + } + response = await client.post(f"{path}/refresh", data=data) + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert response.json() == {"detail": "refresh_tokens_not_allowed"} diff --git a/tests/test_router_users.py b/tests/test_router_users.py index cc4364a9..5655681b 100644 --- a/tests/test_router_users.py +++ b/tests/test_router_users.py @@ -6,7 +6,13 @@ from fastapi_users.authentication import Authenticator from fastapi_users.router import ErrorCode, get_users_router -from tests.conftest import User, UserModel, UserUpdate, get_mock_authentication +from tests.conftest import ( + User, + UserModel, + UserUpdate, + get_mock_authentication, + mock_authorized_headers, +) @pytest.fixture @@ -62,7 +68,7 @@ async def test_inactive_user( ): client, _ = test_app_client response = await client.get( - "/me", headers={"Authorization": f"Bearer {inactive_user.id}"} + "/me", headers=mock_authorized_headers(inactive_user) ) assert response.status_code == status.HTTP_401_UNAUTHORIZED @@ -72,9 +78,7 @@ async def test_active_user( user: UserModel, ): client, requires_verification = test_app_client - response = await client.get( - "/me", headers={"Authorization": f"Bearer {user.id}"} - ) + response = await client.get("/me", headers=mock_authorized_headers(user)) if requires_verification: assert response.status_code == status.HTTP_403_FORBIDDEN else: @@ -90,7 +94,7 @@ async def test_verified_user( ): client, _ = test_app_client response = await client.get( - "/me", headers={"Authorization": f"Bearer {verified_user.id}"} + "/me", headers=mock_authorized_headers(verified_user) ) assert response.status_code == status.HTTP_200_OK data = cast(Dict[str, Any], response.json()) @@ -119,7 +123,7 @@ async def test_inactive_user( ): client, _ = test_app_client response = await client.patch( - "/me", headers={"Authorization": f"Bearer {inactive_user.id}"} + "/me", headers=mock_authorized_headers(inactive_user) ) assert response.status_code == status.HTTP_401_UNAUTHORIZED @@ -133,7 +137,7 @@ async def test_existing_email( response = await client.patch( "/me", json={"email": verified_user.email}, - headers={"Authorization": f"Bearer {user.id}"}, + headers=mock_authorized_headers(user), ) if requires_verification: assert response.status_code == status.HTTP_403_FORBIDDEN @@ -151,7 +155,7 @@ async def test_invalid_password( response = await client.patch( "/me", json={"password": "m"}, - headers={"Authorization": f"Bearer {user.id}"}, + headers=mock_authorized_headers(user), ) if requires_verification: assert response.status_code == status.HTTP_403_FORBIDDEN @@ -170,7 +174,7 @@ async def test_empty_body( ): client, requires_verification = test_app_client response = await client.patch( - "/me", json={}, headers={"Authorization": f"Bearer {user.id}"} + "/me", json={}, headers=mock_authorized_headers(user) ) if requires_verification: assert response.status_code == status.HTTP_403_FORBIDDEN @@ -188,7 +192,7 @@ async def test_valid_body( client, requires_verification = test_app_client json = {"email": "king.arthur@tintagel.bt"} response = await client.patch( - "/me", json=json, headers={"Authorization": f"Bearer {user.id}"} + "/me", json=json, headers=mock_authorized_headers(user) ) if requires_verification: assert response.status_code == status.HTTP_403_FORBIDDEN @@ -206,7 +210,7 @@ async def test_unverified_after_email_change( client, _ = test_app_client json = {"email": "king.arthur@tintagel.bt"} response = await client.patch( - "/me", json=json, headers={"Authorization": f"Bearer {verified_user.id}"} + "/me", json=json, headers=mock_authorized_headers(verified_user) ) assert response.status_code == status.HTTP_200_OK @@ -221,7 +225,7 @@ async def test_valid_body_is_superuser( client, requires_verification = test_app_client json = {"is_superuser": True} response = await client.patch( - "/me", json=json, headers={"Authorization": f"Bearer {user.id}"} + "/me", json=json, headers=mock_authorized_headers(user) ) if requires_verification: assert response.status_code == status.HTTP_403_FORBIDDEN @@ -239,7 +243,7 @@ async def test_valid_body_is_active( client, requires_verification = test_app_client json = {"is_active": False} response = await client.patch( - "/me", json=json, headers={"Authorization": f"Bearer {user.id}"} + "/me", json=json, headers=mock_authorized_headers(user) ) if requires_verification: assert response.status_code == status.HTTP_403_FORBIDDEN @@ -257,7 +261,7 @@ async def test_valid_body_is_verified( client, requires_verification = test_app_client json = {"is_verified": True} response = await client.patch( - "/me", json=json, headers={"Authorization": f"Bearer {user.id}"} + "/me", json=json, headers=mock_authorized_headers(user) ) if requires_verification: assert response.status_code == status.HTTP_403_FORBIDDEN @@ -280,7 +284,7 @@ async def test_valid_body_password( json = {"password": "merlin"} response = await client.patch( - "/me", json=json, headers={"Authorization": f"Bearer {user.id}"} + "/me", json=json, headers=mock_authorized_headers(user) ) if requires_verification: assert response.status_code == status.HTTP_403_FORBIDDEN @@ -298,7 +302,7 @@ async def test_empty_body_verified_user( ): client, _ = test_app_client response = await client.patch( - "/me", json={}, headers={"Authorization": f"Bearer {verified_user.id}"} + "/me", json={}, headers=mock_authorized_headers(verified_user) ) assert response.status_code == status.HTTP_200_OK @@ -313,7 +317,7 @@ async def test_valid_body_verified_user( client, _ = test_app_client json = {"email": "king.arthur@tintagel.bt"} response = await client.patch( - "/me", json=json, headers={"Authorization": f"Bearer {verified_user.id}"} + "/me", json=json, headers=mock_authorized_headers(verified_user) ) assert response.status_code == status.HTTP_200_OK @@ -328,7 +332,7 @@ async def test_valid_body_is_superuser_verified_user( client, _ = test_app_client json = {"is_superuser": True} response = await client.patch( - "/me", json=json, headers={"Authorization": f"Bearer {verified_user.id}"} + "/me", json=json, headers=mock_authorized_headers(verified_user) ) assert response.status_code == status.HTTP_200_OK @@ -343,7 +347,7 @@ async def test_valid_body_is_active_verified_user( client, _ = test_app_client json = {"is_active": False} response = await client.patch( - "/me", json=json, headers={"Authorization": f"Bearer {verified_user.id}"} + "/me", json=json, headers=mock_authorized_headers(verified_user) ) assert response.status_code == status.HTTP_200_OK @@ -358,7 +362,7 @@ async def test_valid_body_is_verified_verified_user( client, _ = test_app_client json = {"is_verified": False} response = await client.patch( - "/me", json=json, headers={"Authorization": f"Bearer {verified_user.id}"} + "/me", json=json, headers=mock_authorized_headers(verified_user) ) assert response.status_code == status.HTTP_200_OK @@ -378,7 +382,7 @@ async def test_valid_body_password_verified_user( json = {"password": "merlin"} response = await client.patch( - "/me", json=json, headers={"Authorization": f"Bearer {verified_user.id}"} + "/me", json=json, headers=mock_authorized_headers(verified_user) ) assert response.status_code == status.HTTP_200_OK assert mock_user_db.update.called is True @@ -403,7 +407,7 @@ async def test_regular_user( client, requires_verification = test_app_client response = await client.get( "/d35d213e-f3d8-4f08-954a-7e0d1bea286f", - headers={"Authorization": f"Bearer {user.id}"}, + headers=mock_authorized_headers(user), ) if requires_verification: assert response.status_code == status.HTTP_403_FORBIDDEN @@ -416,7 +420,7 @@ async def test_verified_user( client, _ = test_app_client response = await client.get( "/d35d213e-f3d8-4f08-954a-7e0d1bea286f", - headers={"Authorization": f"Bearer {verified_user.id}"}, + headers=mock_authorized_headers(verified_user), ) assert response.status_code == status.HTTP_403_FORBIDDEN @@ -428,7 +432,7 @@ async def test_not_existing_user_unverified_superuser( client, requires_verification = test_app_client response = await client.get( "/d35d213e-f3d8-4f08-954a-7e0d1bea286f", - headers={"Authorization": f"Bearer {superuser.id}"}, + headers=mock_authorized_headers(superuser), ) if requires_verification: assert response.status_code == status.HTTP_403_FORBIDDEN @@ -443,7 +447,7 @@ async def test_not_existing_user_verified_superuser( client, _ = test_app_client response = await client.get( "/d35d213e-f3d8-4f08-954a-7e0d1bea286f", - headers={"Authorization": f"Bearer {verified_superuser.id}"}, + headers=mock_authorized_headers(verified_superuser), ) assert response.status_code == status.HTTP_404_NOT_FOUND @@ -455,7 +459,7 @@ async def test_superuser( ): client, requires_verification = test_app_client response = await client.get( - f"/{user.id}", headers={"Authorization": f"Bearer {superuser.id}"} + f"/{user.id}", headers=mock_authorized_headers(superuser) ) if requires_verification: assert response.status_code == status.HTTP_403_FORBIDDEN @@ -474,7 +478,7 @@ async def test_verified_superuser( ): client, _ = test_app_client response = await client.get( - f"/{user.id}", headers={"Authorization": f"Bearer {verified_superuser.id}"} + f"/{user.id}", headers=mock_authorized_headers(verified_superuser) ) assert response.status_code == status.HTTP_200_OK @@ -502,7 +506,7 @@ async def test_regular_user( client, requires_verification = test_app_client response = await client.patch( "/d35d213e-f3d8-4f08-954a-7e0d1bea286f", - headers={"Authorization": f"Bearer {user.id}"}, + headers=mock_authorized_headers(user), ) if requires_verification: assert response.status_code == status.HTTP_403_FORBIDDEN @@ -515,7 +519,7 @@ async def test_verified_user( client, _ = test_app_client response = await client.patch( "/d35d213e-f3d8-4f08-954a-7e0d1bea286f", - headers={"Authorization": f"Bearer {verified_user.id}"}, + headers=mock_authorized_headers(verified_user), ) assert response.status_code == status.HTTP_403_FORBIDDEN @@ -528,7 +532,7 @@ async def test_not_existing_user_unverified_superuser( response = await client.patch( "/d35d213e-f3d8-4f08-954a-7e0d1bea286f", json={}, - headers={"Authorization": f"Bearer {superuser.id}"}, + headers=mock_authorized_headers(superuser), ) if requires_verification: assert response.status_code == status.HTTP_403_FORBIDDEN @@ -544,7 +548,7 @@ async def test_not_existing_user_verified_superuser( response = await client.patch( "/d35d213e-f3d8-4f08-954a-7e0d1bea286f", json={}, - headers={"Authorization": f"Bearer {verified_superuser.id}"}, + headers=mock_authorized_headers(verified_superuser), ) assert response.status_code == status.HTTP_404_NOT_FOUND @@ -556,7 +560,7 @@ async def test_empty_body_unverified_superuser( ): client, requires_verification = test_app_client response = await client.patch( - f"/{user.id}", json={}, headers={"Authorization": f"Bearer {superuser.id}"} + f"/{user.id}", json={}, headers=mock_authorized_headers(superuser) ) if requires_verification: assert response.status_code == status.HTTP_403_FORBIDDEN @@ -576,7 +580,7 @@ async def test_empty_body_verified_superuser( response = await client.patch( f"/{user.id}", json={}, - headers={"Authorization": f"Bearer {verified_superuser.id}"}, + headers=mock_authorized_headers(verified_superuser), ) assert response.status_code == status.HTTP_200_OK @@ -594,7 +598,7 @@ async def test_valid_body_unverified_superuser( response = await client.patch( f"/{user.id}", json=json, - headers={"Authorization": f"Bearer {superuser.id}"}, + headers=mock_authorized_headers(superuser), ) if requires_verification: assert response.status_code == status.HTTP_403_FORBIDDEN @@ -615,7 +619,7 @@ async def test_existing_email_verified_superuser( response = await client.patch( f"/{user.id}", json={"email": verified_user.email}, - headers={"Authorization": f"Bearer {verified_superuser.id}"}, + headers=mock_authorized_headers(verified_superuser), ) assert response.status_code == status.HTTP_400_BAD_REQUEST data = cast(Dict[str, Any], response.json()) @@ -631,7 +635,7 @@ async def test_invalid_password_verified_superuser( response = await client.patch( f"/{user.id}", json={"password": "m"}, - headers={"Authorization": f"Bearer {verified_superuser.id}"}, + headers=mock_authorized_headers(verified_superuser), ) assert response.status_code == status.HTTP_400_BAD_REQUEST data = cast(Dict[str, Any], response.json()) @@ -651,7 +655,7 @@ async def test_valid_body_verified_superuser( response = await client.patch( f"/{user.id}", json=json, - headers={"Authorization": f"Bearer {verified_superuser.id}"}, + headers=mock_authorized_headers(verified_superuser), ) assert response.status_code == status.HTTP_200_OK @@ -669,7 +673,7 @@ async def test_valid_body_is_superuser_unverified_superuser( response = await client.patch( f"/{user.id}", json=json, - headers={"Authorization": f"Bearer {superuser.id}"}, + headers=mock_authorized_headers(superuser), ) if requires_verification: assert response.status_code == status.HTTP_403_FORBIDDEN @@ -690,7 +694,7 @@ async def test_valid_body_is_superuser_verified_superuser( response = await client.patch( f"/{user.id}", json=json, - headers={"Authorization": f"Bearer {verified_superuser.id}"}, + headers=mock_authorized_headers(verified_superuser), ) assert response.status_code == status.HTTP_200_OK @@ -708,7 +712,7 @@ async def test_valid_body_is_active_unverified_superuser( response = await client.patch( f"/{user.id}", json=json, - headers={"Authorization": f"Bearer {superuser.id}"}, + headers=mock_authorized_headers(superuser), ) if requires_verification: assert response.status_code == status.HTTP_403_FORBIDDEN @@ -729,7 +733,7 @@ async def test_valid_body_is_active_verified_superuser( response = await client.patch( f"/{user.id}", json=json, - headers={"Authorization": f"Bearer {verified_superuser.id}"}, + headers=mock_authorized_headers(verified_superuser), ) assert response.status_code == status.HTTP_200_OK @@ -747,7 +751,7 @@ async def test_valid_body_is_verified_unverified_superuser( response = await client.patch( f"/{user.id}", json=json, - headers={"Authorization": f"Bearer {superuser.id}"}, + headers=mock_authorized_headers(superuser), ) if requires_verification: assert response.status_code == status.HTTP_403_FORBIDDEN @@ -768,7 +772,7 @@ async def test_valid_body_is_verified_verified_superuser( response = await client.patch( f"/{user.id}", json=json, - headers={"Authorization": f"Bearer {verified_superuser.id}"}, + headers=mock_authorized_headers(verified_superuser), ) assert response.status_code == status.HTTP_200_OK @@ -791,7 +795,7 @@ async def test_valid_body_password_unverified_superuser( response = await client.patch( f"/{user.id}", json=json, - headers={"Authorization": f"Bearer {superuser.id}"}, + headers=mock_authorized_headers(superuser), ) if requires_verification: assert response.status_code == status.HTTP_403_FORBIDDEN @@ -818,7 +822,7 @@ async def test_valid_body_password_verified_superuser( response = await client.patch( f"/{user.id}", json=json, - headers={"Authorization": f"Bearer {verified_superuser.id}"}, + headers=mock_authorized_headers(verified_superuser), ) assert response.status_code == status.HTTP_200_OK assert mock_user_db.update.called is True @@ -843,7 +847,7 @@ async def test_regular_user( client, requires_verification = test_app_client response = await client.delete( "/d35d213e-f3d8-4f08-954a-7e0d1bea286f", - headers={"Authorization": f"Bearer {user.id}"}, + headers=mock_authorized_headers(user), ) if requires_verification: assert response.status_code == status.HTTP_403_FORBIDDEN @@ -856,7 +860,7 @@ async def test_verified_user( client, _ = test_app_client response = await client.delete( "/d35d213e-f3d8-4f08-954a-7e0d1bea286f", - headers={"Authorization": f"Bearer {verified_user.id}"}, + headers=mock_authorized_headers(verified_user), ) assert response.status_code == status.HTTP_403_FORBIDDEN @@ -868,7 +872,7 @@ async def test_not_existing_user_unverified_superuser( client, requires_verification = test_app_client response = await client.delete( "/d35d213e-f3d8-4f08-954a-7e0d1bea286f", - headers={"Authorization": f"Bearer {superuser.id}"}, + headers=mock_authorized_headers(superuser), ) if requires_verification: assert response.status_code == status.HTTP_403_FORBIDDEN @@ -883,7 +887,7 @@ async def test_not_existing_user_verified_superuser( client, _ = test_app_client response = await client.delete( "/d35d213e-f3d8-4f08-954a-7e0d1bea286f", - headers={"Authorization": f"Bearer {verified_superuser.id}"}, + headers=mock_authorized_headers(verified_superuser), ) assert response.status_code == status.HTTP_404_NOT_FOUND @@ -899,7 +903,7 @@ async def test_unverified_superuser( mocker.spy(mock_user_db, "delete") response = await client.delete( - f"/{user.id}", headers={"Authorization": f"Bearer {superuser.id}"} + f"/{user.id}", headers=mock_authorized_headers(superuser) ) if requires_verification: assert response.status_code == status.HTTP_403_FORBIDDEN @@ -923,7 +927,7 @@ async def test_verified_superuser( mocker.spy(mock_user_db, "delete") response = await client.delete( - f"/{user.id}", headers={"Authorization": f"Bearer {verified_superuser.id}"} + f"/{user.id}", headers=mock_authorized_headers(verified_superuser) ) assert response.status_code == status.HTTP_204_NO_CONTENT assert response.content == b""