diff --git a/app/auth/router.py b/app/auth/router.py index 3f7da33..a4a899a 100644 --- a/app/auth/router.py +++ b/app/auth/router.py @@ -9,10 +9,11 @@ from app.users.models import User from app.config import settings from app.auth import session_store +from app.token import token_store from app.utils import get_current_epoch from .schemas import TokensResponse, SessionInfo from .utils import authenticate_user, hash_password, create_token -from .validator import validate_access_token +from app.token import get_token_payload router = r = APIRouter() @@ -85,11 +86,12 @@ async def login(request: Request, form_data: Annotated[OAuth2PasswordRequestForm detail="Inactive user" ) - # if not user.is_verified: - # raise HTTPException( - # status_code=status.HTTP_403_FORBIDDEN, - # detail="Unverified user" - # ) + if not user.is_verified: + # raise HTTPException( + # status_code=status.HTTP_403_FORBIDDEN, + # detail="Unverified user" + # ) + print("DEBUG: Unverified user") # generate new session id session_id = str(uuid.uuid4()) @@ -104,27 +106,83 @@ async def login(request: Request, form_data: Annotated[OAuth2PasswordRequestForm access_token = create_token( data={"sub": user_id, "email": user.email, "sid": session_id}, expires_delta=access_token_expires_delta) - print(f"access_token: {access_token}") + print(f"access_token: {access_token['token']}, exp: {access_token['payload']['exp']}") # create refresh token refresh_token = create_token( data={"sub": user_id, "sid": session_id}, expires_delta=refresh_token_expires_delta) - print(f"refresh_token: {refresh_token}") + print(f"refresh_token: {refresh_token['token']}, exp: {refresh_token['payload']['exp']}") # Obtain user browser information user_agent = str(request.headers["User-Agent"]) user_host = request.client.host - # create new session and add to active session cache - await create_session(user.id, session_id, user_agent, user_host, refresh_token_expires_delta.seconds) + # create new session and add to session store + await _add_session_to_store(user_id, session_id, user_agent, user_host, refresh_token_expires_delta.seconds) + + # add token IDs to token store + await _add_tokens_to_store(access_token_id=access_token['payload']['jti'], + access_token_ttl=access_token_expires_delta.seconds, + refresh_token_id=refresh_token['payload']['jti'], + refresh_token_ttl=refresh_token_expires_delta.seconds) # Return the access token, refresh token, and token type - return TokensResponse(access_token=access_token, refresh_token=refresh_token) + return TokensResponse(access_token=access_token['token'], refresh_token=refresh_token['token']) + +@r.post("/logout") +async def logout(token_payload: Annotated[dict, Depends(get_token_payload)]): + """ + Logs out the user by revoking the user session. + \f + Args: + token_payload (dict): The access token payload. + + Returns: + dict: A dictionary containing the message "Successfully logged out". + """ + user_id: str = token_payload.get("sub") + session_id: str = token_payload.get("sid") + token_id: str = token_payload.get("jti") -async def create_session(user_id: str, session_id: str, user_agent: str, user_host: str, ttl: int) -> bool: + # revoke session + await session_store.remove(user_id, session_id) + + # revoke tokens + sibling_token_id = await token_store.retrieve(token_id) + + await token_store.remove(token_id) + if sibling_token_id: + await token_store.remove(sibling_token_id) + + return {"message": "Successfully logged out"} + +async def _add_tokens_to_store( + access_token_id: str, + access_token_ttl: int, + refresh_token_id: str, + refresh_token_ttl: int): """ - Create a new session for the user. + Adds access and refresh tokens to the token store. + + Args: + access_token_id (str): The ID of the access token. + access_token_ttl (int): The time-to-live (TTL) of the access token in seconds. + refresh_token_id (str): The ID of the refresh token. + refresh_token_ttl (int): The time-to-live (TTL) of the refresh token in seconds. + + Returns: + None + """ + print(f"Adding access token '{access_token_id}' to token store") + await token_store.add(access_token_id, refresh_token_id, access_token_ttl) + + print(f"Adding refresh token '{refresh_token_id}' to token store") + await token_store.add(refresh_token_id, access_token_id, refresh_token_ttl) + +async def _add_session_to_store(user_id: str, session_id: str, user_agent: str, user_host: str, ttl: int) -> bool: + """ + Create a new user session and add to the session store. Args: user_id (str): The user ID. @@ -142,27 +200,4 @@ async def create_session(user_id: str, session_id: str, user_agent: str, user_ho exp=get_current_epoch() + ttl) # add session id to sessions cache, expiry time = refresh token expiry time - print(f"Adding '{user_id}:{session_id}' to active sessions cache") - return await session_store.add(user_id=user_id, session_id=session_id, value=session_info, ttl=ttl) - -@r.post("/logout") -async def logout(token: str = Depends(oauth2_scheme)): - """ - Logs out the user by revoking the user session. - \f - Args: - token (str): The access token. - - Returns: - dict: A dictionary containing the message "Successfully logged out". - """ - payload = await validate_access_token(token) - - user_id: str = payload.get("sub") - session_id: str = payload.get("sid") - - # revoke session - await session_store.remove(user_id, session_id) - - return {"message": "Successfully logged out"} diff --git a/app/auth/utils.py b/app/auth/utils.py index 7a55002..57bb6bd 100644 --- a/app/auth/utils.py +++ b/app/auth/utils.py @@ -51,7 +51,7 @@ async def authenticate_user(email: str, password: str, db: AsyncSession): return user -def create_token(data: dict, expires_delta: timedelta) -> str: +def create_token(data: dict, expires_delta: timedelta) -> dict: """ Create a JSON Web Token (JWT) with the given dictionary data and expiration delta. @@ -60,17 +60,27 @@ def create_token(data: dict, expires_delta: timedelta) -> str: expires_delta (timedelta): The expiration time delta for the token. Returns: - str: The encoded JWT. - + dict: The token and its payload. """ + # Copy the input data to avoid modifying the original dictionary to_encode = data.copy() + + # If an expiration delta is provided, calculate the expiration time by adding it to the current time if expires_delta: expire = datetime.utcnow() + expires_delta + + # If no expiration delta is provided, use the default token expiration time else: expire = datetime.utcnow() + timedelta(minutes=DEFAULT_TOKEN_EXPIRE_MINUTES) + # Add a unique identifier to the token data to_encode.update({"jti": str(uuid.uuid4())}) + # Add the expiration time to the token data to_encode.update({"exp": expire}) + # Add the issued at time to the token data to_encode.update({"iat": datetime.utcnow()}) - return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) + return { + 'token': jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM), + 'payload': to_encode + } \ No newline at end of file diff --git a/app/user.py b/app/user.py index 027638c..1812ea3 100644 --- a/app/user.py +++ b/app/user.py @@ -1,92 +1,68 @@ from typing import Annotated from uuid import UUID from fastapi import Depends, HTTPException, status -from fastapi.security import OAuth2PasswordBearer -from jose import JWTError, jwt -from pydantic import ValidationError from sqlalchemy import select from app.users.models import User from app.database import get_db, AsyncSession -from .config import settings +from app.token import get_token_payload -oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/login") - -SECRET_KEY = settings.SECRET_KEY -ALGORITHM = settings.ALGORITHM - -async def current_user(token: Annotated[str, Depends(oauth2_scheme)], - db: Annotated[AsyncSession, Depends(get_db)]) -> User: +async def current_user(token_payload: Annotated[dict, Depends(get_token_payload)], + db: Annotated[AsyncSession, Depends(get_db)]) -> User: """ Retrieve the current user based on the provided token. Args: - token (str): The token to use to retrieve the user. + token_payload (dict): The token payload to use to retrieve the user. db (AsyncSession): The database async session. Returns: User: The retrieved user object. """ - return await get_current_user(token, db) + return await get_current_user(token_payload, db) -async def current_active_user(token: Annotated[str, Depends(oauth2_scheme)], +async def current_active_user(token_payload: Annotated[dict, Depends(get_token_payload)], db: Annotated[AsyncSession, Depends(get_db)]) -> User: """ Retrieve the current active user based on the provided token. Args: - token (str): The token to use to retrieve the user. + token_payload (dict): The token payload to use to retrieve the user. db (AsyncSession): The database async session. Returns: User: The retrieved user object. """ - return await get_current_user(token, db, active=True) + return await get_current_user(token_payload, db, active=True) -async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)], +async def get_current_user(token_payload: Annotated[dict, Depends(get_token_payload)], db: Annotated[AsyncSession, Depends(get_db)], active: bool = False, verified: bool = False, superuser: bool = False) -> User: """ - Retrieve the current user based on the provided JWT token. - - This function decodes the JWT token to get the user_id and then retrieves the user from the - database. It also checks if the user is active, verified, and a superuser based on the provided - parameters. + Retrieve the current user based on the provided token payload. Args: - token (str): The JWT token of the user. - db (AsyncSession): The database async session. - active (bool, optional): Whether the user needs to be active. Defaults to False. - verified (bool, optional): Whether the user needs to be verified. Defaults to False. - superuser (bool, optional): Whether the user needs to be a superuser. Defaults to False. + token_payload (dict): The payload of the JWT token. + db (AsyncSession): The database session. + active (bool, optional): Filter users by active status. Defaults to False. + verified (bool, optional): Filter users by verified status. Defaults to False. + superuser (bool, optional): Filter users by superuser status. Defaults to False. Returns: - User: The retrieved user object. + User: The current user. Raises: - HTTPException: If the token is invalid, the user is not found, the user is inactive, the - user is unverified, or the user is not a superuser. + HTTPException: If the token is invalid or there is an error while decoding it. """ status_code = status.HTTP_401_UNAUTHORIZED - try: - # Set the options for the JWT token validation - options = {"require_exp": True, "require_sub": True, "require_jti": True} - - # Decode, validate JWT token and get the payload - payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM], options=options) - - # Get the user_id from the payload - user_id: str = payload.get("sub") - - # If the user_id is None, the token is invalid - if user_id is None: - raise HTTPException(status_code=status_code, detail="Invalid token") + # Get the user_id from the payload + user_id: str = token_payload.get("sub") - except (JWTError, ValidationError) as exc: - # If there is an error while decoding the token, it is invalid - raise HTTPException(status_code=status_code, detail=str(exc)) from exc + # If the user_id is None, the token is invalid + if user_id is None: + raise HTTPException(status_code=status_code, detail="Invalid token") # Retrieve the user from the database user = await get_user_by_id(user_id=user_id, db=db)