Skip to content

Commit

Permalink
Update create_token function to return token and payload
Browse files Browse the repository at this point in the history
  • Loading branch information
Dennis Lee committed Mar 6, 2024
1 parent ca54251 commit 4d68012
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 86 deletions.
107 changes: 71 additions & 36 deletions app/auth/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@
from app.users.models import User
from app.config import settings
from app.auth import session_store
from app.token import token_store
from app.utils import get_current_epoch
from .schemas import TokensResponse, SessionInfo
from .utils import authenticate_user, hash_password, create_token
from .validator import validate_access_token
from app.token import get_token_payload

router = r = APIRouter()

Expand Down Expand Up @@ -85,11 +86,12 @@ async def login(request: Request, form_data: Annotated[OAuth2PasswordRequestForm
detail="Inactive user"
)

# if not user.is_verified:
# raise HTTPException(
# status_code=status.HTTP_403_FORBIDDEN,
# detail="Unverified user"
# )
if not user.is_verified:
# raise HTTPException(
# status_code=status.HTTP_403_FORBIDDEN,
# detail="Unverified user"
# )
print("DEBUG: Unverified user")

# generate new session id
session_id = str(uuid.uuid4())
Expand All @@ -104,27 +106,83 @@ async def login(request: Request, form_data: Annotated[OAuth2PasswordRequestForm
access_token = create_token(
data={"sub": user_id, "email": user.email, "sid": session_id},
expires_delta=access_token_expires_delta)
print(f"access_token: {access_token}")
print(f"access_token: {access_token['token']}, exp: {access_token['payload']['exp']}")

# create refresh token
refresh_token = create_token(
data={"sub": user_id, "sid": session_id},
expires_delta=refresh_token_expires_delta)
print(f"refresh_token: {refresh_token}")
print(f"refresh_token: {refresh_token['token']}, exp: {refresh_token['payload']['exp']}")

# Obtain user browser information
user_agent = str(request.headers["User-Agent"])
user_host = request.client.host

# create new session and add to active session cache
await create_session(user.id, session_id, user_agent, user_host, refresh_token_expires_delta.seconds)
# create new session and add to session store
await _add_session_to_store(user_id, session_id, user_agent, user_host, refresh_token_expires_delta.seconds)

# add token IDs to token store
await _add_tokens_to_store(access_token_id=access_token['payload']['jti'],
access_token_ttl=access_token_expires_delta.seconds,
refresh_token_id=refresh_token['payload']['jti'],
refresh_token_ttl=refresh_token_expires_delta.seconds)

# Return the access token, refresh token, and token type
return TokensResponse(access_token=access_token, refresh_token=refresh_token)
return TokensResponse(access_token=access_token['token'], refresh_token=refresh_token['token'])

@r.post("/logout")
async def logout(token_payload: Annotated[dict, Depends(get_token_payload)]):
"""
Logs out the user by revoking the user session.
\f
Args:
token_payload (dict): The access token payload.
Returns:
dict: A dictionary containing the message "Successfully logged out".
"""
user_id: str = token_payload.get("sub")
session_id: str = token_payload.get("sid")
token_id: str = token_payload.get("jti")

async def create_session(user_id: str, session_id: str, user_agent: str, user_host: str, ttl: int) -> bool:
# revoke session
await session_store.remove(user_id, session_id)

# revoke tokens
sibling_token_id = await token_store.retrieve(token_id)

await token_store.remove(token_id)
if sibling_token_id:
await token_store.remove(sibling_token_id)

return {"message": "Successfully logged out"}

async def _add_tokens_to_store(
access_token_id: str,
access_token_ttl: int,
refresh_token_id: str,
refresh_token_ttl: int):
"""
Create a new session for the user.
Adds access and refresh tokens to the token store.
Args:
access_token_id (str): The ID of the access token.
access_token_ttl (int): The time-to-live (TTL) of the access token in seconds.
refresh_token_id (str): The ID of the refresh token.
refresh_token_ttl (int): The time-to-live (TTL) of the refresh token in seconds.
Returns:
None
"""
print(f"Adding access token '{access_token_id}' to token store")
await token_store.add(access_token_id, refresh_token_id, access_token_ttl)

print(f"Adding refresh token '{refresh_token_id}' to token store")
await token_store.add(refresh_token_id, access_token_id, refresh_token_ttl)

async def _add_session_to_store(user_id: str, session_id: str, user_agent: str, user_host: str, ttl: int) -> bool:
"""
Create a new user session and add to the session store.
Args:
user_id (str): The user ID.
Expand All @@ -142,27 +200,4 @@ async def create_session(user_id: str, session_id: str, user_agent: str, user_ho
exp=get_current_epoch() + ttl)

# add session id to sessions cache, expiry time = refresh token expiry time
print(f"Adding '{user_id}:{session_id}' to active sessions cache")

return await session_store.add(user_id=user_id, session_id=session_id, value=session_info, ttl=ttl)

@r.post("/logout")
async def logout(token: str = Depends(oauth2_scheme)):
"""
Logs out the user by revoking the user session.
\f
Args:
token (str): The access token.
Returns:
dict: A dictionary containing the message "Successfully logged out".
"""
payload = await validate_access_token(token)

user_id: str = payload.get("sub")
session_id: str = payload.get("sid")

# revoke session
await session_store.remove(user_id, session_id)

return {"message": "Successfully logged out"}
18 changes: 14 additions & 4 deletions app/auth/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ async def authenticate_user(email: str, password: str, db: AsyncSession):
return user


def create_token(data: dict, expires_delta: timedelta) -> str:
def create_token(data: dict, expires_delta: timedelta) -> dict:
"""
Create a JSON Web Token (JWT) with the given dictionary data and expiration delta.
Expand All @@ -60,17 +60,27 @@ def create_token(data: dict, expires_delta: timedelta) -> str:
expires_delta (timedelta): The expiration time delta for the token.
Returns:
str: The encoded JWT.
dict: The token and its payload.
"""
# Copy the input data to avoid modifying the original dictionary
to_encode = data.copy()

# If an expiration delta is provided, calculate the expiration time by adding it to the current time
if expires_delta:
expire = datetime.utcnow() + expires_delta

# If no expiration delta is provided, use the default token expiration time
else:
expire = datetime.utcnow() + timedelta(minutes=DEFAULT_TOKEN_EXPIRE_MINUTES)

# Add a unique identifier to the token data
to_encode.update({"jti": str(uuid.uuid4())})
# Add the expiration time to the token data
to_encode.update({"exp": expire})
# Add the issued at time to the token data
to_encode.update({"iat": datetime.utcnow()})

return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
return {
'token': jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM),
'payload': to_encode
}
68 changes: 22 additions & 46 deletions app/user.py
Original file line number Diff line number Diff line change
@@ -1,92 +1,68 @@
from typing import Annotated
from uuid import UUID
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
from pydantic import ValidationError
from sqlalchemy import select
from app.users.models import User
from app.database import get_db, AsyncSession
from .config import settings
from app.token import get_token_payload

oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/login")

SECRET_KEY = settings.SECRET_KEY
ALGORITHM = settings.ALGORITHM

async def current_user(token: Annotated[str, Depends(oauth2_scheme)],
db: Annotated[AsyncSession, Depends(get_db)]) -> User:
async def current_user(token_payload: Annotated[dict, Depends(get_token_payload)],
db: Annotated[AsyncSession, Depends(get_db)]) -> User:
"""
Retrieve the current user based on the provided token.
Args:
token (str): The token to use to retrieve the user.
token_payload (dict): The token payload to use to retrieve the user.
db (AsyncSession): The database async session.
Returns:
User: The retrieved user object.
"""
return await get_current_user(token, db)
return await get_current_user(token_payload, db)

async def current_active_user(token: Annotated[str, Depends(oauth2_scheme)],
async def current_active_user(token_payload: Annotated[dict, Depends(get_token_payload)],
db: Annotated[AsyncSession, Depends(get_db)]) -> User:
"""
Retrieve the current active user based on the provided token.
Args:
token (str): The token to use to retrieve the user.
token_payload (dict): The token payload to use to retrieve the user.
db (AsyncSession): The database async session.
Returns:
User: The retrieved user object.
"""
return await get_current_user(token, db, active=True)
return await get_current_user(token_payload, db, active=True)

async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)],
async def get_current_user(token_payload: Annotated[dict, Depends(get_token_payload)],
db: Annotated[AsyncSession, Depends(get_db)],
active: bool = False,
verified: bool = False,
superuser: bool = False) -> User:
"""
Retrieve the current user based on the provided JWT token.
This function decodes the JWT token to get the user_id and then retrieves the user from the
database. It also checks if the user is active, verified, and a superuser based on the provided
parameters.
Retrieve the current user based on the provided token payload.
Args:
token (str): The JWT token of the user.
db (AsyncSession): The database async session.
active (bool, optional): Whether the user needs to be active. Defaults to False.
verified (bool, optional): Whether the user needs to be verified. Defaults to False.
superuser (bool, optional): Whether the user needs to be a superuser. Defaults to False.
token_payload (dict): The payload of the JWT token.
db (AsyncSession): The database session.
active (bool, optional): Filter users by active status. Defaults to False.
verified (bool, optional): Filter users by verified status. Defaults to False.
superuser (bool, optional): Filter users by superuser status. Defaults to False.
Returns:
User: The retrieved user object.
User: The current user.
Raises:
HTTPException: If the token is invalid, the user is not found, the user is inactive, the
user is unverified, or the user is not a superuser.
HTTPException: If the token is invalid or there is an error while decoding it.
"""
status_code = status.HTTP_401_UNAUTHORIZED

try:
# Set the options for the JWT token validation
options = {"require_exp": True, "require_sub": True, "require_jti": True}

# Decode, validate JWT token and get the payload
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM], options=options)

# Get the user_id from the payload
user_id: str = payload.get("sub")

# If the user_id is None, the token is invalid
if user_id is None:
raise HTTPException(status_code=status_code, detail="Invalid token")
# Get the user_id from the payload
user_id: str = token_payload.get("sub")

except (JWTError, ValidationError) as exc:
# If there is an error while decoding the token, it is invalid
raise HTTPException(status_code=status_code, detail=str(exc)) from exc
# If the user_id is None, the token is invalid
if user_id is None:
raise HTTPException(status_code=status_code, detail="Invalid token")

# Retrieve the user from the database
user = await get_user_by_id(user_id=user_id, db=db)
Expand Down

0 comments on commit 4d68012

Please sign in to comment.